Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 2 additions & 73 deletions generators/rust/dynamic-snippets/src/EndpointSnippetGenerator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ export class EndpointSnippetGenerator {
snippet: FernIr.dynamic.EndpointSnippetRequest;
}): string[] {
// Get use statements
const useStatements = this.getUseStatements({ endpoint, snippet });
const useStatements = this.getUseStatements();

// Create the main function body
const mainBody = rust.CodeBlock.fromStatements([
Expand Down Expand Up @@ -99,52 +99,9 @@ export class EndpointSnippetGenerator {
return components;
}

private getUseStatements({
endpoint,
snippet
}: {
endpoint: FernIr.dynamic.Endpoint;
snippet: FernIr.dynamic.EndpointSnippetRequest;
}): rust.UseStatement[] {
const stdImports = new Set<string>();
const chronoImports = new Set<string>();
const uuidImports = new Set<string>();

// Collect types used in snippet values that require std/chrono/uuid imports
this.collectSnippetTypeImports(snippet, new Set<string>(), stdImports, chronoImports, uuidImports);

private getUseStatements(): rust.UseStatement[] {
const useStatements: rust.UseStatement[] = [];

// Add standard library imports if needed
if (stdImports.size > 0) {
useStatements.push(
new rust.UseStatement({
path: "std::collections",
items: Array.from(stdImports)
})
);
}

// Add chrono imports if needed
if (chronoImports.size > 0) {
useStatements.push(
new rust.UseStatement({
path: "chrono",
items: Array.from(chronoImports)
})
);
}

// Add UUID imports if needed
if (uuidImports.size > 0) {
useStatements.push(
new rust.UseStatement({
path: "uuid",
items: Array.from(uuidImports)
})
);
}

// Use prelude import for all crate types
useStatements.push(
new rust.UseStatement({
Expand All @@ -156,34 +113,6 @@ export class EndpointSnippetGenerator {
return useStatements;
}

// New method to collect types from snippet values
private collectSnippetTypeImports(
snippet: FernIr.dynamic.EndpointSnippetRequest,
imports: Set<string>,
stdImports: Set<string>,
chronoImports: Set<string>,
uuidImports: Set<string>
): void {
// Collect types from request body if present
if (snippet.requestBody != null) {
this.collectTypesFromValue(snippet.requestBody, imports, stdImports, chronoImports, uuidImports);
}

// Collect types from query parameters
if (snippet.queryParameters != null) {
Object.values(snippet.queryParameters).forEach((value) => {
this.collectTypesFromValue(value, imports, stdImports, chronoImports, uuidImports);
});
}

// Collect types from headers
if (snippet.headers != null) {
Object.values(snippet.headers).forEach((value) => {
this.collectTypesFromValue(value, imports, stdImports, chronoImports, uuidImports);
});
}
}

// Helper to collect type imports from a value by analyzing its structure
private collectTypesFromValue(
value: unknown,
Expand Down
89 changes: 2 additions & 87 deletions generators/rust/model/src/alias/AliasGenerator.ts
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import { RelativeFilePath } from "@fern-api/fs-utils";
import { RustFile } from "@fern-api/rust-base";
import { Attribute, PUBLIC, rust } from "@fern-api/rust-codegen";
import { AliasTypeDeclaration, TypeDeclaration, TypeReference } from "@fern-fern/ir-sdk/api";
import { AliasTypeDeclaration, TypeDeclaration } from "@fern-fern/ir-sdk/api";
import { generateRustTypeForTypeReference } from "../converters";
import { ModelGeneratorContext } from "../ModelGeneratorContext";
import { isChronoType, isCollectionType, isUuidType, typeSupportsHashAndEq } from "../utils/primitiveTypeUtils";
import { typeSupportsHashAndEq } from "../utils/primitiveTypeUtils";

export class AliasGenerator {
private readonly typeDeclaration: TypeDeclaration;
Expand Down Expand Up @@ -52,45 +52,6 @@ export class AliasGenerator {
return writer.toString();
}

private writeUseStatements(writer: rust.Writer): void {
writer.writeLine("use serde::{Deserialize, Serialize};");

// Add additional use statements based on the inner type
this.writeAdditionalUseStatements(writer);
}

private writeAdditionalUseStatements(writer: rust.Writer): void {
const innerType = this.aliasTypeDeclaration.aliasOf;

// Add imports for custom named types FIRST
const customTypes = this.getCustomTypesUsedInAlias();
customTypes.forEach((typeName) => {
const modulePath = this.context.getModulePathForType(typeName.snakeCase.unsafeName);
const moduleNameEscaped = this.context.escapeRustKeyword(modulePath);
writer.writeLine(`use crate::${moduleNameEscaped}::${typeName.pascalCase.unsafeName};`);
});

// Add chrono if aliasing a datetime
if (isChronoType(innerType)) {
writer.writeLine("use chrono::{DateTime, Utc};");
}

// Add uuid if aliasing a UUID
if (isUuidType(innerType)) {
writer.writeLine("use uuid::Uuid;");
}

// Add collections if aliasing a map or set
if (isCollectionType(innerType)) {
writer.writeLine("use std::collections::HashMap;");
}

// TODO: @iamnamananand996 build to use serde_json::Value ---> Value directly
// if (hasJsonValueFields(properties)) {
// writer.writeLine("use serde_json::Value;");
// }
}

private generateNewtypeForTypeDeclaration(): rust.NewtypeStruct {
return rust.newtypeStruct({
name: this.context.getUniqueTypeNameForDeclaration(this.typeDeclaration),
Expand Down Expand Up @@ -124,50 +85,4 @@ export class AliasGenerator {
// Check if the aliased type can support Hash and Eq derives
return typeSupportsHashAndEq(this.aliasTypeDeclaration.aliasOf, this.context);
}

private getCustomTypesUsedInAlias(): {
snakeCase: { unsafeName: string };
pascalCase: { unsafeName: string };
}[] {
const customTypeNames: {
snakeCase: { unsafeName: string };
pascalCase: { unsafeName: string };
}[] = [];
const visited = new Set<string>();

const extractNamedTypesRecursively = (typeRef: TypeReference) => {
if (typeRef.type === "named") {
const typeName = typeRef.name.originalName;
if (!visited.has(typeName)) {
visited.add(typeName);
customTypeNames.push({
snakeCase: { unsafeName: typeRef.name.snakeCase.unsafeName },
pascalCase: { unsafeName: typeRef.name.pascalCase.unsafeName }
});
}
} else if (typeRef.type === "container") {
typeRef.container._visit({
list: (listType: TypeReference) => extractNamedTypesRecursively(listType),
set: (setType: TypeReference) => extractNamedTypesRecursively(setType),
optional: (optionalType: TypeReference) => extractNamedTypesRecursively(optionalType),
nullable: (nullableType: TypeReference) => extractNamedTypesRecursively(nullableType),
map: (mapType) => {
extractNamedTypesRecursively(mapType.keyType);
extractNamedTypesRecursively(mapType.valueType);
},
literal: () => {
// No named types in literals
},
_other: () => {
// Unknown container type
}
});
}
};

// Analyze the aliased type
extractNamedTypesRecursively(this.aliasTypeDeclaration.aliasOf);

return customTypeNames;
}
}
5 changes: 0 additions & 5 deletions generators/rust/model/src/enum/EnumGenerator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,6 @@ export class EnumGenerator {
return writer.toString();
}

private writeUseStatements(writer: rust.Writer): void {
writer.writeLine("use serde::{Deserialize, Serialize};");
writer.writeLine("use std::fmt;");
}

private generateEnumForTypeDeclaration(): rust.Enum {
return rust.enum_({
name: this.context.getUniqueTypeNameForDeclaration(this.typeDeclaration),
Expand Down
87 changes: 1 addition & 86 deletions generators/rust/model/src/object/StructGenerator.ts
Original file line number Diff line number Diff line change
@@ -1,25 +1,15 @@
import { RelativeFilePath } from "@fern-api/fs-utils";
import { RustFile } from "@fern-api/rust-base";
import { Attribute, PUBLIC, rust } from "@fern-api/rust-codegen";

import { ObjectProperty, ObjectTypeDeclaration, TypeDeclaration } from "@fern-fern/ir-sdk/api";

import { ModelGeneratorContext } from "../ModelGeneratorContext";
import { namedTypeSupportsHashAndEq, namedTypeSupportsPartialEq } from "../utils/primitiveTypeUtils";
import { isFieldRecursive } from "../utils/recursiveTypeUtils";
import {
canDeriveHashAndEq,
canDerivePartialEq,
generateFieldAttributes,
generateFieldType,
getCustomTypesUsedInFields,
hasBigIntFields,
hasDateFields,
hasDateTimeOnlyFields,
hasFloatingPointSets,
hasHashMapFields,
hasHashSetFields,
hasUuidFields
generateFieldType
} from "../utils/structUtils";

export class StructGenerator {
Expand Down Expand Up @@ -68,81 +58,6 @@ export class StructGenerator {
return writer.toString();
}

private writeUseStatements(writer: rust.Writer): void {
// Add imports for custom named types referenced in fields FIRST
const customTypes = getCustomTypesUsedInFields(
this.objectTypeDeclaration.properties,
this.typeDeclaration.name.name.pascalCase.unsafeName
);
customTypes.forEach((typeName) => {
const modulePath = this.context.getModulePathForType(typeName.snakeCase.unsafeName);
const moduleNameEscaped = this.context.escapeRustKeyword(modulePath);
writer.writeLine(`use crate::${moduleNameEscaped}::${typeName.pascalCase.unsafeName};`);
});

// Add imports for parent types
if (this.objectTypeDeclaration.extends.length > 0) {
this.objectTypeDeclaration.extends.forEach((parentType) => {
// Use getUniqueTypeNameForReference to get the correct type name with fernFilepath prefix
const parentTypeName = this.context.getUniqueTypeNameForReference(parentType);
const modulePath = this.context.getModulePathForType(parentType.name.snakeCase.unsafeName);
const moduleNameEscaped = this.context.escapeRustKeyword(modulePath);
writer.writeLine(`use crate::${moduleNameEscaped}::${parentTypeName};`);
});
}

// Add chrono imports based on specific types needed
const hasDateOnly = hasDateFields(this.objectTypeDeclaration.properties);
const hasDateTimeOnly = hasDateTimeOnlyFields(this.objectTypeDeclaration.properties);

// TODO: @iamnamananand996 - use AST mechanism for all imports
if (hasDateOnly && hasDateTimeOnly) {
// Both date and datetime types present
writer.writeLine("use chrono::{DateTime, NaiveDate, Utc};");
} else if (hasDateOnly) {
// Only date type present, import NaiveDate only
writer.writeLine("use chrono::NaiveDate;");
} else if (hasDateTimeOnly) {
// Only datetime type present, import DateTime and Utc only
writer.writeLine("use chrono::{DateTime, Utc};");
}

// Add std::collections imports based on specific collection types used
const needsHashMap = hasHashMapFields(this.objectTypeDeclaration.properties);
const needsHashSet = hasHashSetFields(this.objectTypeDeclaration.properties);

if (needsHashMap && needsHashSet) {
writer.writeLine("use std::collections::{HashMap, HashSet};");
} else if (needsHashMap) {
writer.writeLine("use std::collections::HashMap;");
} else if (needsHashSet) {
writer.writeLine("use std::collections::HashSet;");
}

// Add ordered_float if we have floating-point sets
if (hasFloatingPointSets(this.objectTypeDeclaration.properties)) {
writer.writeLine("use ordered_float::OrderedFloat;");
}

// Add uuid if we have UUID fields
if (hasUuidFields(this.objectTypeDeclaration.properties)) {
writer.writeLine("use uuid::Uuid;");
}

// Add num_bigint if we have BigInt fields
if (hasBigIntFields(this.objectTypeDeclaration.properties)) {
writer.writeLine("use num_bigint::BigInt;");
}

// TODO: @iamnamananand996 build to use serde_json::Value ---> Value directly
// if (hasJsonValueFields(properties)) {
// writer.writeLine("use serde_json::Value;");
// }

// Add serde imports LAST
writer.writeLine("use serde::{Deserialize, Serialize};");
}

private generateStructForTypeDeclaration(): rust.Struct {
const fields: rust.Field[] = [];

Expand Down
Loading
Loading