Skip to content

Commit 48f2e5a

Browse files
Add enums_as_ints flag to protoc-gen-swagger (#1186)
* feat: add enumAsInts flag * Update protoc-gen-swagger/genswagger/template.go Co-Authored-By: Johan Brandhorst <[email protected]> * fix: typo * feat: add comment * feat: add test Co-authored-by: Johan Brandhorst <[email protected]>
1 parent 5e68b7e commit 48f2e5a

File tree

4 files changed

+228
-2
lines changed

4 files changed

+228
-2
lines changed

protoc-gen-grpc-gateway/descriptor/registry.go

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,9 @@ type Registry struct {
7676
// in your protofile comments
7777
useGoTemplate bool
7878

79+
// enumsAsInts render enum as integer, as opposed to string
80+
enumsAsInts bool
81+
7982
// disableDefaultErrors disables the generation of the default error types.
8083
// This is useful for users who have defined custom error handling.
8184
disableDefaultErrors bool
@@ -487,6 +490,16 @@ func (r *Registry) GetUseGoTemplate() bool {
487490
return r.useGoTemplate
488491
}
489492

493+
// SetEnumsAsInts set enumsAsInts
494+
func (r *Registry) SetEnumsAsInts(enumsAsInts bool) {
495+
r.enumsAsInts = enumsAsInts
496+
}
497+
498+
// GetEnumsAsInts returns enumsAsInts
499+
func (r *Registry) GetEnumsAsInts() bool {
500+
return r.enumsAsInts
501+
}
502+
490503
// SetDisableDefaultErrors sets disableDefaultErrors
491504
func (r *Registry) SetDisableDefaultErrors(use bool) {
492505
r.disableDefaultErrors = use

protoc-gen-swagger/genswagger/template.go

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,13 @@ func listEnumNames(enum *descriptor.Enum) (names []string) {
9292
return names
9393
}
9494

95+
func listEnumNumbers(enum *descriptor.Enum) (numbers []string) {
96+
for _, value := range enum.GetValue() {
97+
numbers = append(numbers, strconv.Itoa(int(value.GetNumber())))
98+
}
99+
return
100+
}
101+
95102
func getEnumDefault(enum *descriptor.Enum) string {
96103
for _, value := range enum.GetValue() {
97104
if value.GetNumber() == 0 {
@@ -182,10 +189,19 @@ func queryParams(message *descriptor.Message, field *descriptor.Field, prefix st
182189
Type: "string",
183190
Enum: listEnumNames(enum),
184191
}
192+
if reg.GetEnumsAsInts() {
193+
param.Items.Type = "integer"
194+
param.Items.Enum = listEnumNumbers(enum)
195+
}
185196
} else {
186197
param.Type = "string"
187198
param.Enum = listEnumNames(enum)
188199
param.Default = getEnumDefault(enum)
200+
if reg.GetEnumsAsInts() {
201+
param.Type = "integer"
202+
param.Enum = listEnumNumbers(enum)
203+
param.Default = "0"
204+
}
189205
}
190206
valueComments := enumValueProtoComments(reg, enum)
191207
if valueComments != "" {
@@ -531,6 +547,12 @@ func renderEnumerationsAsDefinition(enums enumMap, d swaggerDefinitionsObject, r
531547
Default: defaultValue,
532548
},
533549
}
550+
if reg.GetEnumsAsInts() {
551+
enumSchemaObject.Type = "integer"
552+
enumSchemaObject.Format = "int32"
553+
enumSchemaObject.Default = "0"
554+
enumSchemaObject.Enum = listEnumNumbers(enum)
555+
}
534556
if err := updateSwaggerDataFromComments(reg, &enumSchemaObject, enum, enumComments, false); err != nil {
535557
panic(err)
536558
}
@@ -737,13 +759,18 @@ func renderServices(services []*descriptor.Service, paths swaggerPathsObject, re
737759
return fmt.Errorf("only primitive and well-known types are allowed in path parameters")
738760
}
739761
case pbdescriptor.FieldDescriptorProto_TYPE_ENUM:
740-
paramType = "string"
741-
paramFormat = ""
742762
enum, err := reg.LookupEnum("", parameter.Target.GetTypeName())
743763
if err != nil {
744764
return err
745765
}
766+
paramType = "string"
767+
paramFormat = ""
746768
enumNames = listEnumNames(enum)
769+
if reg.GetEnumsAsInts() {
770+
paramType = "integer"
771+
paramFormat = ""
772+
enumNames = listEnumNumbers(enum)
773+
}
747774
schema := schemaOfField(parameter.Target, reg, customRefs)
748775
desc = schema.Description
749776
defaultValue = schema.Default
@@ -1470,6 +1497,9 @@ func enumValueProtoComments(reg *descriptor.Registry, enum *descriptor.Enum) str
14701497
var comments []string
14711498
for idx, value := range enum.GetValue() {
14721499
name := value.GetName()
1500+
if reg.GetEnumsAsInts() {
1501+
name = strconv.Itoa(int(value.GetNumber()))
1502+
}
14731503
str := protoComments(reg, enum.File, enum.Outers, "EnumType", int32(enum.Index), protoPath, int32(idx))
14741504
if str != "" {
14751505
comments = append(comments, name+": "+str)

protoc-gen-swagger/genswagger/template_test.go

Lines changed: 181 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,187 @@ func reqFromFile(f *descriptor.File) *plugin.CodeGeneratorRequest {
4646
}
4747
}
4848

49+
func TestMessageToQueryParametersWithEnumAsInt(t *testing.T) {
50+
type test struct {
51+
MsgDescs []*protodescriptor.DescriptorProto
52+
Message string
53+
Params []swaggerParameterObject
54+
}
55+
56+
tests := []test{
57+
{
58+
MsgDescs: []*protodescriptor.DescriptorProto{
59+
&protodescriptor.DescriptorProto{
60+
Name: proto.String("ExampleMessage"),
61+
Field: []*protodescriptor.FieldDescriptorProto{
62+
{
63+
Name: proto.String("a"),
64+
Type: protodescriptor.FieldDescriptorProto_TYPE_STRING.Enum(),
65+
Number: proto.Int32(1),
66+
},
67+
{
68+
Name: proto.String("b"),
69+
Type: protodescriptor.FieldDescriptorProto_TYPE_DOUBLE.Enum(),
70+
Number: proto.Int32(2),
71+
},
72+
{
73+
Name: proto.String("c"),
74+
Type: protodescriptor.FieldDescriptorProto_TYPE_STRING.Enum(),
75+
Label: protodescriptor.FieldDescriptorProto_LABEL_REPEATED.Enum(),
76+
Number: proto.Int32(3),
77+
},
78+
},
79+
},
80+
},
81+
Message: "ExampleMessage",
82+
Params: []swaggerParameterObject{
83+
swaggerParameterObject{
84+
Name: "a",
85+
In: "query",
86+
Required: false,
87+
Type: "string",
88+
},
89+
swaggerParameterObject{
90+
Name: "b",
91+
In: "query",
92+
Required: false,
93+
Type: "number",
94+
Format: "double",
95+
},
96+
swaggerParameterObject{
97+
Name: "c",
98+
In: "query",
99+
Required: false,
100+
Type: "array",
101+
CollectionFormat: "multi",
102+
},
103+
},
104+
},
105+
{
106+
MsgDescs: []*protodescriptor.DescriptorProto{
107+
&protodescriptor.DescriptorProto{
108+
Name: proto.String("ExampleMessage"),
109+
Field: []*protodescriptor.FieldDescriptorProto{
110+
{
111+
Name: proto.String("nested"),
112+
Type: protodescriptor.FieldDescriptorProto_TYPE_MESSAGE.Enum(),
113+
TypeName: proto.String(".example.Nested"),
114+
Number: proto.Int32(1),
115+
},
116+
},
117+
},
118+
&protodescriptor.DescriptorProto{
119+
Name: proto.String("Nested"),
120+
Field: []*protodescriptor.FieldDescriptorProto{
121+
{
122+
Name: proto.String("a"),
123+
Type: protodescriptor.FieldDescriptorProto_TYPE_STRING.Enum(),
124+
Number: proto.Int32(1),
125+
},
126+
{
127+
Name: proto.String("deep"),
128+
Type: protodescriptor.FieldDescriptorProto_TYPE_MESSAGE.Enum(),
129+
TypeName: proto.String(".example.Nested.DeepNested"),
130+
Number: proto.Int32(2),
131+
},
132+
},
133+
NestedType: []*protodescriptor.DescriptorProto{{
134+
Name: proto.String("DeepNested"),
135+
Field: []*protodescriptor.FieldDescriptorProto{
136+
{
137+
Name: proto.String("b"),
138+
Type: protodescriptor.FieldDescriptorProto_TYPE_STRING.Enum(),
139+
Number: proto.Int32(1),
140+
},
141+
{
142+
Name: proto.String("c"),
143+
Type: protodescriptor.FieldDescriptorProto_TYPE_ENUM.Enum(),
144+
TypeName: proto.String(".example.Nested.DeepNested.DeepEnum"),
145+
Number: proto.Int32(2),
146+
},
147+
},
148+
EnumType: []*protodescriptor.EnumDescriptorProto{
149+
{
150+
Name: proto.String("DeepEnum"),
151+
Value: []*protodescriptor.EnumValueDescriptorProto{
152+
{Name: proto.String("FALSE"), Number: proto.Int32(0)},
153+
{Name: proto.String("TRUE"), Number: proto.Int32(1)},
154+
},
155+
},
156+
},
157+
}},
158+
},
159+
},
160+
Message: "ExampleMessage",
161+
Params: []swaggerParameterObject{
162+
swaggerParameterObject{
163+
Name: "nested.a",
164+
In: "query",
165+
Required: false,
166+
Type: "string",
167+
},
168+
swaggerParameterObject{
169+
Name: "nested.deep.b",
170+
In: "query",
171+
Required: false,
172+
Type: "string",
173+
},
174+
swaggerParameterObject{
175+
Name: "nested.deep.c",
176+
In: "query",
177+
Required: false,
178+
Type: "integer",
179+
Enum: []string{"0", "1"},
180+
Default: "0",
181+
},
182+
},
183+
},
184+
}
185+
186+
for _, test := range tests {
187+
reg := descriptor.NewRegistry()
188+
reg.SetEnumsAsInts(true)
189+
msgs := []*descriptor.Message{}
190+
for _, msgdesc := range test.MsgDescs {
191+
msgs = append(msgs, &descriptor.Message{DescriptorProto: msgdesc})
192+
}
193+
file := descriptor.File{
194+
FileDescriptorProto: &protodescriptor.FileDescriptorProto{
195+
SourceCodeInfo: &protodescriptor.SourceCodeInfo{},
196+
Name: proto.String("example.proto"),
197+
Package: proto.String("example"),
198+
Dependency: []string{},
199+
MessageType: test.MsgDescs,
200+
Service: []*protodescriptor.ServiceDescriptorProto{},
201+
},
202+
GoPkg: descriptor.GoPackage{
203+
Path: "example.com/path/to/example/example.pb",
204+
Name: "example_pb",
205+
},
206+
Messages: msgs,
207+
}
208+
reg.Load(&plugin.CodeGeneratorRequest{
209+
ProtoFile: []*protodescriptor.FileDescriptorProto{file.FileDescriptorProto},
210+
})
211+
212+
message, err := reg.LookupMsg("", ".example."+test.Message)
213+
if err != nil {
214+
t.Fatalf("failed to lookup message: %s", err)
215+
}
216+
params, err := messageToQueryParameters(message, reg, []descriptor.Parameter{})
217+
if err != nil {
218+
t.Fatalf("failed to convert message to query parameters: %s", err)
219+
}
220+
// avoid checking Items for array types
221+
for i := range params {
222+
params[i].Items = nil
223+
}
224+
if !reflect.DeepEqual(params, test.Params) {
225+
t.Errorf("expected %v, got %v", test.Params, params)
226+
}
227+
}
228+
}
229+
49230
func TestMessageToQueryParameters(t *testing.T) {
50231
type test struct {
51232
MsgDescs []*protodescriptor.DescriptorProto

protoc-gen-swagger/main.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ var (
2929
useFQNForSwaggerName = flag.Bool("fqn_for_swagger_name", false, "if set, the object's swagger names will use the fully qualify name from the proto definition (ie my.package.MyMessage.MyInnerMessage")
3030
useGoTemplate = flag.Bool("use_go_templates", false, "if set, you can use Go templates in protofile comments")
3131
disableDefaultErrors = flag.Bool("disable_default_errors", false, "if set, disables generation of default errors. This is useful if you have defined custom error handling")
32+
enumsAsInts = flag.Bool("enums_as_ints", false, "whether to render enum values as integers, as opposed to string values")
3233
)
3334

3435
// Variables set by goreleaser at build time
@@ -81,6 +82,7 @@ func main() {
8182
reg.SetIncludePackageInTags(*includePackageInTags)
8283
reg.SetUseFQNForSwaggerName(*useFQNForSwaggerName)
8384
reg.SetUseGoTemplate(*useGoTemplate)
85+
reg.SetEnumsAsInts(*enumsAsInts)
8486
reg.SetDisableDefaultErrors(*disableDefaultErrors)
8587
if err := reg.SetRepeatedPathParamSeparator(*repeatedPathParamSeparator); err != nil {
8688
emitError(err)

0 commit comments

Comments
 (0)