Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 24 additions & 2 deletions generators/go-v2/dynamic-snippets/src/EndpointSnippetGenerator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -333,8 +333,7 @@ export class EndpointSnippetGenerator {
case "header":
return values.type === "header" ? this.getConstructorHeaderAuthArg({ auth, values }) : TypeInst.nop();
case "oauth":
this.addWarning("The Go SDK doesn't support OAuth client credentials yet");
return TypeInst.nop();
return values.type === "oauth" ? this.getConstructorOAuthArg({ auth, values }) : TypeInst.nop();
case "inferred":
this.addWarning("The Go SDK Generator does not support Inferred auth scheme yet");
return TypeInst.nop();
Expand Down Expand Up @@ -483,6 +482,29 @@ export class EndpointSnippetGenerator {
});
}

private getConstructorOAuthArg({
auth,
values
}: {
auth: FernIr.dynamic.OAuth;
values: FernIr.dynamic.OAuthValues;
}): go.AstNode {
return go.codeblock((writer) => {
writer.writeNode(
go.invokeFunc({
func: go.typeReference({
name: "WithClientCredentials",
importPath: this.context.getOptionImportPath()
}),
arguments_: [
go.TypeInstantiation.string(values.clientId),
go.TypeInstantiation.string(values.clientSecret)
]
})
);
});
}

private getConstructorHeaderArgs({
headers,
values
Expand Down
156 changes: 156 additions & 0 deletions generators/go/internal/generator/generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -580,6 +580,9 @@ func (g *Generator) generate(ir *fernir.IntermediateRepresentation, mode Mode) (
files = append(files, newMultipartFile(g.coordinator))
files = append(files, newMultipartTestFile(g.coordinator))
}
if needsOAuthHelpers(ir) {
files = append(files, newOAuthFile(g.coordinator))
}
clientTestFile, err := newClientTestFile(g.config.FullImportPath, rootPackageName, g.coordinator, g.config.ClientName, g.config.ClientConstructorName)
if err != nil {
return nil, err
Expand Down Expand Up @@ -766,6 +769,7 @@ func (g *Generator) generateRootService(
ir.Errors,
g.coordinator,
)
oauthConfig := computeOAuthClientCredentialsConfig(ir, g.config.FullImportPath)
generatedClient, err := writer.WriteClient(
ir.Auth,
irService.Endpoints,
Expand All @@ -781,6 +785,7 @@ func (g *Generator) generateRootService(
g.config.InlineFileProperties,
g.config.ClientName,
g.config.ClientConstructorName,
oauthConfig,
)
if err != nil {
return nil, nil, err
Expand Down Expand Up @@ -832,6 +837,7 @@ func (g *Generator) generateService(
g.config.InlineFileProperties,
"",
"",
nil, // OAuth config is only for root client
)
if err != nil {
return nil, nil, err
Expand Down Expand Up @@ -886,6 +892,7 @@ func (g *Generator) generateServiceWithoutEndpoints(
g.config.InlineFileProperties,
"",
"",
nil, // OAuth config is only for root client
); err != nil {
return nil, err
}
Expand Down Expand Up @@ -920,6 +927,7 @@ func (g *Generator) generateRootServiceWithoutEndpoints(
ir.Errors,
g.coordinator,
)
oauthConfig := computeOAuthClientCredentialsConfig(ir, g.config.FullImportPath)
generatedClient, err := writer.WriteClient(
ir.Auth,
nil,
Expand All @@ -935,6 +943,7 @@ func (g *Generator) generateRootServiceWithoutEndpoints(
g.config.InlineFileProperties,
g.config.ClientName,
g.config.ClientConstructorName,
oauthConfig,
)
if err != nil {
return nil, nil, err
Expand Down Expand Up @@ -1266,6 +1275,14 @@ func newOptionalTestFile(coordinator *coordinator.Client) *File {
)
}

func newOAuthFile(coordinator *coordinator.Client) *File {
return NewFile(
coordinator,
"core/oauth.go",
[]byte(oauthFile),
)
}

func newQueryFile(coordinator *coordinator.Client) *File {
return NewFile(
coordinator,
Expand Down Expand Up @@ -1900,6 +1917,145 @@ func needsFileUploadHelpers(ir *fernir.IntermediateRepresentation) bool {
return false
}

// needsOAuthHelpers returns true if OAuth is used in the IR.
func needsOAuthHelpers(ir *fernir.IntermediateRepresentation) bool {
if ir.Auth == nil {
return false
}
for _, authScheme := range ir.Auth.Schemes {
if authScheme.Oauth != nil {
return true
}
}
return false
}

// computeOAuthClientCredentialsConfig extracts the OAuth client credentials configuration
// from the IR and returns a config struct that can be used to generate the token refresh code.
func computeOAuthClientCredentialsConfig(
ir *fernir.IntermediateRepresentation,
baseImportPath string,
) *OAuthClientCredentialsConfig {
if ir.Auth == nil {
return nil
}

// Find the OAuth client credentials scheme
var oauthScheme *fernir.OAuthScheme
for _, authScheme := range ir.Auth.Schemes {
if authScheme.Oauth != nil {
oauthScheme = authScheme.Oauth
break
}
}
if oauthScheme == nil || oauthScheme.Configuration == nil {
return nil
}

// Only support client credentials for now
clientCreds := oauthScheme.Configuration.ClientCredentials
if clientCreds == nil || clientCreds.TokenEndpoint == nil {
return nil
}

tokenEndpoint := clientCreds.TokenEndpoint
endpointRef := tokenEndpoint.EndpointReference
if endpointRef == nil {
return nil
}

// Find the endpoint by ID to get the method name
var endpoint *fernir.HttpEndpoint
for _, service := range ir.Services {
for _, ep := range service.Endpoints {
if ep.Id == endpointRef.EndpointId {
endpoint = ep
break
}
}
if endpoint != nil {
break
}
}
if endpoint == nil {
return nil
}

// Determine the token client import path based on the subpackage
var tokenClientImportPath string
if endpointRef.SubpackageId != nil {
subpackage := ir.Subpackages[*endpointRef.SubpackageId]
if subpackage != nil {
tokenClientImportPath = packagePathToImportPath(baseImportPath, packagePathForClient(subpackage.FernFilepath))
}
}
if tokenClientImportPath == "" {
// Token endpoint is in the root package - use the client package
tokenClientImportPath = baseImportPath + "/client"
}

// Get the endpoint method name
tokenEndpointMethod := endpoint.Name.PascalCase.UnsafeName

// Get the request type info
// The request type is typically in the root package (fern)
requestImportPath := baseImportPath
var requestType string
if endpoint.SdkRequest != nil && endpoint.SdkRequest.Shape != nil {
if wrapper := endpoint.SdkRequest.Shape.Wrapper; wrapper != nil {
requestType = wrapper.WrapperName.PascalCase.UnsafeName
}
}
if requestType == "" {
// Fallback: construct from endpoint name
requestType = endpoint.Name.PascalCase.UnsafeName + "Request"
}

// Get request property field names
reqProps := tokenEndpoint.RequestProperties
if reqProps == nil || reqProps.ClientId == nil || reqProps.ClientSecret == nil {
return nil
}

clientIDFieldName := "ClientId"
clientSecretFieldName := "ClientSecret"
if reqProps.ClientId.Property != nil && reqProps.ClientId.Property.Body != nil {
clientIDFieldName = reqProps.ClientId.Property.Body.Name.Name.PascalCase.UnsafeName
}
if reqProps.ClientSecret.Property != nil && reqProps.ClientSecret.Property.Body != nil {
clientSecretFieldName = reqProps.ClientSecret.Property.Body.Name.Name.PascalCase.UnsafeName
}

// Get response property field names
respProps := tokenEndpoint.ResponseProperties
if respProps == nil || respProps.AccessToken == nil {
return nil
}

accessTokenFieldName := "AccessToken"
if respProps.AccessToken.Property != nil {
accessTokenFieldName = respProps.AccessToken.Property.Name.Name.PascalCase.UnsafeName
}

hasExpiresIn := respProps.ExpiresIn != nil
expiresInFieldName := "ExpiresIn"
if hasExpiresIn && respProps.ExpiresIn.Property != nil {
expiresInFieldName = respProps.ExpiresIn.Property.Name.Name.PascalCase.UnsafeName
}

return &OAuthClientCredentialsConfig{
TokenClientImportPath: tokenClientImportPath,
TokenEndpointMethod: tokenEndpointMethod,
RequestImportPath: requestImportPath,
RequestType: requestType,
ClientIDFieldName: clientIDFieldName,
ClientSecretFieldName: clientSecretFieldName,
AccessTokenFieldName: accessTokenFieldName,
ExpiresInFieldName: expiresInFieldName,
HasExpiresIn: hasExpiresIn,
}
}

func isReservedFilename(filename string) bool {
_, ok := reservedFilenames[filename]
return ok
Expand Down
Loading
Loading