Skip to content

Commit 12ebd24

Browse files
committed
feat(sidekick): Promote sample generation information from Rust annotations to model
1 parent 4f9cc64 commit 12ebd24

File tree

11 files changed

+357
-325
lines changed

11 files changed

+357
-325
lines changed

internal/sidekick/api/model.go

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1228,6 +1228,9 @@ type Enum struct {
12281228
// The unique integer values, some enums have multiple aliases for the
12291229
// same number (e.g. `enum X { a = 0, b = 0, c = 1 }`).
12301230
UniqueNumberValues []*EnumValue
1231+
// ValuesForExamples contains a subset of values suitable for use in generated samples.
1232+
// e.g. non-deprecated, non-zero values.
1233+
ValuesForExamples []*SampleValue
12311234
// Parent returns the ancestor of this node, if any.
12321235
Parent *Message
12331236
// The Protobuf package this enum belongs to.
@@ -1254,6 +1257,14 @@ type EnumValue struct {
12541257
Codec any
12551258
}
12561259

1260+
// SampleValue represents a value used in a sample.
1261+
type SampleValue struct {
1262+
// The enum value.
1263+
EnumValue *EnumValue
1264+
// The index of the value in the sample list (0-based).
1265+
Index int
1266+
}
1267+
12571268
// Field defines a field in a Message.
12581269
type Field struct {
12591270
// Documentation for the field.
@@ -1456,6 +1467,10 @@ type OneOf struct {
14561467
Documentation string
14571468
// Fields associated with the one-of.
14581469
Fields []*Field
1470+
// The best field to show in a oneof related samples.
1471+
// Non deprecated fields are preferred, then scalar, repeated, map fields
1472+
// in that order.
1473+
ExampleField *Field
14591474
// Codec is a placeholder to put language specific annotations.
14601475
Codec any
14611476
}

internal/sidekick/api/xref.go

Lines changed: 105 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,13 @@
1414

1515
package api
1616

17-
import "fmt"
17+
import (
18+
"fmt"
19+
"slices"
20+
"strings"
21+
22+
"github.com/iancoleman/strcase"
23+
)
1824

1925
// CrossReference fills out the cross-references in `model` that the parser(s)
2026
// missed.
@@ -83,5 +89,103 @@ func CrossReference(model *API) error {
8389
}
8490
}
8591
}
92+
enrichSamples(model)
8693
return nil
8794
}
95+
96+
// enrichSamples populates the API model with information useful for generating code samples.
97+
// This includes selecting representative enum values and optimal fields for oneof structures.
98+
func enrichSamples(model *API) {
99+
for _, e := range model.State.EnumByID {
100+
enrichEnumSamples(e)
101+
}
102+
103+
for _, m := range model.State.MessageByID {
104+
for _, o := range m.OneOfs {
105+
if len(o.Fields) > 0 {
106+
o.ExampleField = slices.MaxFunc(o.Fields, sortOneOfFieldForExamples)
107+
}
108+
}
109+
}
110+
}
111+
112+
func enrichEnumSamples(e *Enum) {
113+
// We try to pick some good enum values to show in examples.
114+
// - We pick values that are not deprecated.
115+
// - We don't pick the default value (Number 0).
116+
// - We try to avoid duplicates (e.g. FULL vs full).
117+
var goodValues []*EnumValue
118+
seen := make(map[string]bool)
119+
120+
for _, ev := range e.Values {
121+
// A simple heuristic to avoid duplicates.
122+
// This is not perfect, but it should handle the most common cases.
123+
name := strcase.ToCamel(strings.ToLower(ev.Name))
124+
if seen[name] {
125+
continue
126+
}
127+
seen[name] = true
128+
129+
if !ev.Deprecated && ev.Number != 0 {
130+
goodValues = append(goodValues, ev)
131+
}
132+
}
133+
// If we couldn't find any good enum values for examples, then we pick from all enum values.
134+
if len(goodValues) == 0 {
135+
// Reset seen map and try again without filters, but still deduplicating.
136+
seen = make(map[string]bool)
137+
for _, ev := range e.Values {
138+
name := strcase.ToCamel(strings.ToLower(ev.Name))
139+
if seen[name] {
140+
continue
141+
}
142+
seen[name] = true
143+
goodValues = append(goodValues, ev)
144+
}
145+
}
146+
// We pick at most 3 values as samples do not need to be exhaustive.
147+
goodValues = goodValues[:min(3, len(goodValues))]
148+
149+
e.ValuesForExamples = make([]*SampleValue, len(goodValues))
150+
for i, ev := range goodValues {
151+
e.ValuesForExamples[i] = &SampleValue{
152+
EnumValue: ev,
153+
Index: i,
154+
}
155+
}
156+
}
157+
158+
// sortOneOfFieldForExamples is used to select the "best" field for an example.
159+
//
160+
// Fields are lexicographically sorted by the tuple:
161+
//
162+
// (f.Deprecated, f.Map, f.Repeated, f.Message != nil)
163+
//
164+
// Where `false` values are preferred over `true` values. That is, we prefer
165+
// fields that are **not** deprecated, but if both fields have the same
166+
// `Deprecated` value then we prefer the field that is **not** a map, and so on.
167+
//
168+
// The return value is either -1, 0, or 1 to use in the standard library sorting
169+
// functions.
170+
func sortOneOfFieldForExamples(f1, f2 *Field) int {
171+
compare := func(a, b bool) int {
172+
switch {
173+
case a == b:
174+
return 0
175+
case a:
176+
return -1
177+
default:
178+
return 1
179+
}
180+
}
181+
if v := compare(f1.Deprecated, f2.Deprecated); v != 0 {
182+
return v
183+
}
184+
if v := compare(f1.Map, f2.Map); v != 0 {
185+
return v
186+
}
187+
if v := compare(f1.Repeated, f2.Repeated); v != 0 {
188+
return v
189+
}
190+
return compare(f1.MessageType != nil, f2.MessageType != nil)
191+
}

internal/sidekick/api/xref_test.go

Lines changed: 204 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,9 @@ package api
1717
import (
1818
"fmt"
1919
"testing"
20+
21+
"github.com/google/go-cmp/cmp"
22+
"github.com/google/go-cmp/cmp/cmpopts"
2023
)
2124

2225
func TestCrossReferenceOneOfs(t *testing.T) {
@@ -214,3 +217,204 @@ func TestCrossReferenceService(t *testing.T) {
214217
t.Errorf("mismatched model, got=%v, want=%v", mixin.Model, model)
215218
}
216219
}
220+
221+
func TestEnrichSamplesEnumValues(t *testing.T) {
222+
v_good1 := &EnumValue{Name: "GOOD_1", Number: 1}
223+
v_good2 := &EnumValue{Name: "GOOD_2", Number: 2}
224+
v_good3 := &EnumValue{Name: "GOOD_3", Number: 3}
225+
v_good4 := &EnumValue{Name: "GOOD_4", Number: 4}
226+
v_bad_deprecated := &EnumValue{Name: "BAD_DEPRECATED", Number: 5, Deprecated: true}
227+
v_bad_default := &EnumValue{Name: "BAD_DEFAULT", Number: 0}
228+
229+
testCases := []struct {
230+
name string
231+
values []*EnumValue
232+
wantExamples []*SampleValue
233+
}{
234+
{
235+
name: "more than 3 good values",
236+
values: []*EnumValue{v_good1, v_good2, v_good3, v_good4},
237+
wantExamples: []*SampleValue{
238+
{EnumValue: v_good1, Index: 0},
239+
{EnumValue: v_good2, Index: 1},
240+
{EnumValue: v_good3, Index: 2},
241+
},
242+
},
243+
{
244+
name: "less than 3 good values",
245+
values: []*EnumValue{v_good1, v_good2, v_bad_deprecated},
246+
wantExamples: []*SampleValue{
247+
{EnumValue: v_good1, Index: 0},
248+
{EnumValue: v_good2, Index: 1},
249+
},
250+
},
251+
{
252+
name: "no good values",
253+
values: []*EnumValue{v_bad_default, v_bad_deprecated},
254+
wantExamples: []*SampleValue{
255+
{EnumValue: v_bad_default, Index: 0},
256+
{EnumValue: v_bad_deprecated, Index: 1},
257+
},
258+
},
259+
{
260+
name: "no values",
261+
values: []*EnumValue{},
262+
wantExamples: []*SampleValue{},
263+
},
264+
{
265+
name: "mixed good and bad values",
266+
values: []*EnumValue{v_bad_default, v_good1, v_bad_deprecated, v_good2},
267+
wantExamples: []*SampleValue{
268+
{EnumValue: v_good1, Index: 0},
269+
{EnumValue: v_good2, Index: 1},
270+
},
271+
},
272+
}
273+
274+
for _, tc := range testCases {
275+
t.Run(tc.name, func(t *testing.T) {
276+
enum := &Enum{
277+
Name: "TestEnum",
278+
ID: ".test.v1.TestEnum",
279+
Package: "test.v1",
280+
Values: tc.values,
281+
}
282+
model := NewTestAPI([]*Message{}, []*Enum{enum}, []*Service{})
283+
if err := CrossReference(model); err != nil {
284+
t.Fatalf("CrossReference() failed: %v", err)
285+
}
286+
287+
got := enum.ValuesForExamples
288+
if diff := cmp.Diff(tc.wantExamples, got, cmpopts.IgnoreFields(EnumValue{}, "Parent")); diff != "" {
289+
t.Errorf("mismatch in ValuesForExamples (-want, +got)\n:%s", diff)
290+
}
291+
})
292+
}
293+
}
294+
295+
func TestEnrichSamplesOneOfExampleField(t *testing.T) {
296+
deprecated := &Field{
297+
Name: "deprecated_field",
298+
ID: ".test.Message.deprecated_field",
299+
Typez: STRING_TYPE,
300+
IsOneOf: true,
301+
Deprecated: true,
302+
}
303+
mapMessage := &Message{
304+
Name: "$map<string, string>",
305+
ID: "$map<string, string>",
306+
IsMap: true,
307+
Fields: []*Field{
308+
{Name: "key", ID: "$map<string, string>.key", Typez: STRING_TYPE},
309+
{Name: "value", ID: "$map<string, string>.value", Typez: STRING_TYPE},
310+
},
311+
}
312+
mapField := &Field{
313+
Name: "map_field",
314+
ID: ".test.Message.map_field",
315+
Typez: MESSAGE_TYPE,
316+
TypezID: "$map<string, string>",
317+
IsOneOf: true,
318+
Map: true,
319+
}
320+
repeated := &Field{
321+
Name: "repeated_field",
322+
ID: ".test.Message.repeated_field",
323+
Typez: STRING_TYPE,
324+
Repeated: true,
325+
IsOneOf: true,
326+
}
327+
scalar := &Field{
328+
Name: "scalar_field",
329+
ID: ".test.Message.scalar_field",
330+
Typez: INT32_TYPE,
331+
IsOneOf: true,
332+
}
333+
messageField := &Field{
334+
Name: "message_field",
335+
ID: ".test.Message.message_field",
336+
Typez: MESSAGE_TYPE,
337+
TypezID: ".test.OneMessage",
338+
IsOneOf: true,
339+
}
340+
anotherMessageField := &Field{
341+
Name: "another_message_field",
342+
ID: ".test.Message.another_message_field",
343+
Typez: MESSAGE_TYPE,
344+
TypezID: ".test.AnotherMessage",
345+
IsOneOf: true,
346+
}
347+
348+
testCases := []struct {
349+
name string
350+
fields []*Field
351+
want *Field
352+
}{
353+
{
354+
name: "all types",
355+
fields: []*Field{deprecated, mapField, repeated, scalar, messageField},
356+
want: scalar,
357+
},
358+
{
359+
name: "no primitives",
360+
fields: []*Field{deprecated, mapField, repeated, messageField},
361+
want: messageField,
362+
},
363+
{
364+
name: "only scalars and messages",
365+
fields: []*Field{messageField, scalar, anotherMessageField},
366+
want: scalar,
367+
},
368+
{
369+
name: "no scalars",
370+
fields: []*Field{deprecated, mapField, repeated},
371+
want: repeated,
372+
},
373+
{
374+
name: "only map and deprecated",
375+
fields: []*Field{deprecated, mapField},
376+
want: mapField,
377+
},
378+
{
379+
name: "only deprecated",
380+
fields: []*Field{deprecated},
381+
want: deprecated,
382+
},
383+
}
384+
385+
for _, tc := range testCases {
386+
t.Run(tc.name, func(t *testing.T) {
387+
group := &OneOf{
388+
Name: "test_oneof",
389+
ID: ".test.Message.test_oneof",
390+
Fields: tc.fields,
391+
}
392+
message := &Message{
393+
Name: "Message",
394+
ID: ".test.Message",
395+
Package: "test",
396+
Fields: tc.fields,
397+
OneOfs: []*OneOf{group},
398+
}
399+
oneMesage := &Message{
400+
Name: "OneMessage",
401+
ID: ".test.OneMessage",
402+
Package: "test",
403+
}
404+
anotherMessage := &Message{
405+
Name: "AnotherMessage",
406+
ID: ".test.AnotherMessage",
407+
Package: "test",
408+
}
409+
model := NewTestAPI([]*Message{message, oneMesage, anotherMessage, mapMessage}, []*Enum{}, []*Service{})
410+
if err := CrossReference(model); err != nil {
411+
t.Fatal(err)
412+
}
413+
414+
got := group.ExampleField
415+
if diff := cmp.Diff(tc.want, got); diff != "" {
416+
t.Errorf("mismatch in ExampleField (-want, +got)\n:%s", diff)
417+
}
418+
})
419+
}
420+
}

0 commit comments

Comments
 (0)