Skip to content

Commit 85c5053

Browse files
committed
Fix duplicate struct/function generation; add test to ensure unique code output
1 parent 00503f1 commit 85c5053

File tree

3 files changed

+302
-62
lines changed

3 files changed

+302
-62
lines changed

internal/output/write.go

Lines changed: 223 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -9,102 +9,101 @@ import (
99
"io/fs"
1010
"os"
1111
"path/filepath"
12+
"strings"
1213
)
1314

1415
// WriteDataSources uses the packageName to determine whether to create a directory and package per data source.
1516
// If packageName is an empty string, this indicates that the flag was not set, and the default behaviour is
1617
// then to create a package and directory per data source. If packageName is set then all generated code is
1718
// placed into the same directory and package.
1819
func WriteDataSources(dataSourcesSchema, dataSourcesModels, customTypeValue, dataSourcesToFrom map[string][]byte, outputDir, packageName string) error {
19-
for k, v := range dataSourcesSchema {
20-
dirName := ""
20+
for k, v := range dataSourcesSchema {
21+
dirName := ""
2122

22-
if packageName == "" {
23-
dirName = fmt.Sprintf("datasource_%s", k)
23+
if packageName == "" {
24+
dirName = fmt.Sprintf("datasource_%s", k)
2425

25-
err := os.MkdirAll(filepath.Join(outputDir, dirName), os.ModePerm)
26-
if err != nil {
27-
return err
26+
err := os.MkdirAll(filepath.Join(outputDir, dirName), os.ModePerm)
27+
if err != nil {
28+
return err
29+
}
2830
}
29-
}
3031

31-
filename := fmt.Sprintf("%s_data_source_gen.go", k)
32+
filename := fmt.Sprintf("%s_data_source_gen.go", k)
3233

33-
f, err := os.Create(filepath.Join(outputDir, dirName, filename))
34-
if err != nil {
35-
return err
36-
}
37-
38-
_, err = f.Write(v)
39-
if err != nil {
40-
return err
41-
}
34+
// Combine all content first
35+
var allContent []byte
36+
allContent = append(allContent, v...)
37+
allContent = append(allContent, dataSourcesModels[k]...)
38+
allContent = append(allContent, customTypeValue[k]...)
39+
allContent = append(allContent, dataSourcesToFrom[k]...)
4240

43-
_, err = f.Write(dataSourcesModels[k])
44-
if err != nil {
45-
return err
46-
}
41+
// Deduplicate the combined content
42+
deduplicated, err := deduplicateGoCode(allContent)
43+
if err != nil {
44+
return err
45+
}
4746

48-
_, err = f.Write(customTypeValue[k])
49-
if err != nil {
50-
return err
51-
}
47+
f, err := os.Create(filepath.Join(outputDir, dirName, filename))
48+
if err != nil {
49+
return err
50+
}
51+
defer f.Close()
5252

53-
_, err = f.Write(dataSourcesToFrom[k])
54-
if err != nil {
55-
return err
53+
_, err = f.Write(deduplicated)
54+
if err != nil {
55+
return err
56+
}
5657
}
57-
}
5858

59-
return nil
59+
return nil
6060
}
6161

6262
// WriteResources uses the packageName to determine whether to create a directory and package per resource.
6363
// If packageName is an empty string, this indicates that the flag was not set, and the default behaviour is
6464
// then to create a package and directory per resource. If packageName is set then all generated code is
6565
// placed into the same directory and package.
6666
func WriteResources(resourcesSchema, resourcesModels, customTypeValue, resourcesToFrom map[string][]byte, outputDir, packageName string) error {
67-
for k, v := range resourcesSchema {
68-
dirName := ""
67+
for k, v := range resourcesSchema {
68+
dirName := ""
6969

70-
if packageName == "" {
71-
dirName = fmt.Sprintf("resource_%s", k)
70+
if packageName == "" {
71+
dirName = fmt.Sprintf("resource_%s", k)
7272

73-
err := os.MkdirAll(filepath.Join(outputDir, dirName), os.ModePerm)
74-
if err != nil {
75-
return err
73+
err := os.MkdirAll(filepath.Join(outputDir, dirName), os.ModePerm)
74+
if err != nil {
75+
return err
76+
}
7677
}
77-
}
7878

79-
filename := fmt.Sprintf("%s_resource_gen.go", k)
79+
filename := fmt.Sprintf("%s_resource_gen.go", k)
8080

81-
f, err := os.Create(filepath.Join(outputDir, dirName, filename))
82-
if err != nil {
83-
return err
84-
}
81+
// Combine all content first
82+
var allContent []byte
83+
allContent = append(allContent, v...)
84+
allContent = append(allContent, resourcesModels[k]...)
85+
allContent = append(allContent, customTypeValue[k]...)
86+
allContent = append(allContent, resourcesToFrom[k]...)
8587

86-
_, err = f.Write(v)
87-
if err != nil {
88-
return err
89-
}
90-
91-
_, err = f.Write(resourcesModels[k])
92-
if err != nil {
93-
return err
94-
}
88+
// Deduplicate the combined content
89+
deduplicated, err := deduplicateGoCode(allContent)
90+
if err != nil {
91+
return err
92+
}
9593

96-
_, err = f.Write(customTypeValue[k])
97-
if err != nil {
98-
return err
99-
}
94+
f, err := os.Create(filepath.Join(outputDir, dirName, filename))
95+
if err != nil {
96+
return err
97+
}
98+
defer f.Close()
10099

101-
_, err = f.Write(resourcesToFrom[k])
102-
if err != nil {
103-
return err
100+
_, err = f.Write(deduplicated)
101+
if err != nil {
102+
return err
103+
}
104104
}
105-
}
106105

107-
return nil
106+
return nil
108107
}
109108

110109
// WriteProviders uses the packageName to determine whether to create a directory and package for the provider.
@@ -173,3 +172,165 @@ func WriteBytes(outputFilePath string, outputBytes []byte, forceOverwrite bool)
173172

174173
return nil
175174
}
175+
176+
// deduplicateGoCode removes duplicate type and function declarations from Go source code
177+
func deduplicateGoCode(content []byte) ([]byte, error) {
178+
source := string(content)
179+
lines := strings.Split(source, "\n")
180+
181+
// Track seen declarations
182+
seen := make(map[string]bool)
183+
result := make([]string, 0, len(lines))
184+
185+
i := 0
186+
for i < len(lines) {
187+
line := lines[i]
188+
trimmedLine := strings.TrimSpace(line)
189+
190+
// Check for type declarations
191+
if strings.HasPrefix(trimmedLine, "type ") {
192+
// Extract type name
193+
fields := strings.Fields(trimmedLine)
194+
if len(fields) >= 2 {
195+
typeName := fields[1]
196+
key := "type:" + typeName
197+
198+
if seen[key] {
199+
// Skip this entire type declaration
200+
i = skipGoDeclaration(lines, i)
201+
continue
202+
} else {
203+
seen[key] = true
204+
}
205+
}
206+
}
207+
208+
// Check for function declarations
209+
if strings.HasPrefix(trimmedLine, "func ") {
210+
// Extract function name
211+
funcName := extractFunctionName(trimmedLine)
212+
if funcName != "" {
213+
key := "func:" + funcName
214+
215+
if seen[key] {
216+
// Skip this entire function declaration
217+
i = skipGoDeclaration(lines, i)
218+
continue
219+
} else {
220+
seen[key] = true
221+
}
222+
}
223+
}
224+
225+
// Check for var declarations
226+
if strings.HasPrefix(trimmedLine, "var _ ") {
227+
// Extract the type being checked
228+
parts := strings.Split(trimmedLine, "=")
229+
if len(parts) > 1 {
230+
rightPart := strings.TrimSpace(parts[1])
231+
// Extract type name from "TypeName{}" pattern
232+
if strings.HasSuffix(rightPart, "{}") {
233+
typeName := strings.TrimSpace(strings.TrimSuffix(rightPart, "{}"))
234+
key := "var:" + typeName
235+
236+
if seen[key] {
237+
// Skip this var declaration
238+
i++
239+
continue
240+
} else {
241+
seen[key] = true
242+
}
243+
}
244+
}
245+
}
246+
247+
result = append(result, line)
248+
i++
249+
}
250+
251+
return []byte(strings.Join(result, "\n")), nil
252+
}
253+
254+
// extractFunctionName extracts function name from a function declaration line
255+
func extractFunctionName(line string) string {
256+
// Handle both regular functions and methods
257+
// func Name(...) or func (receiver) Name(...)
258+
fields := strings.Fields(line)
259+
if len(fields) < 2 {
260+
return ""
261+
}
262+
263+
if strings.HasPrefix(fields[1], "(") {
264+
// Method with receiver: func (r Type) Name(...)
265+
if len(fields) >= 4 {
266+
// Extract receiver type and method name to create unique identifier
267+
receiverPart := fields[2] // This should be the type name like "PrincipalType)"
268+
funcName := fields[3]
269+
270+
// Clean up receiver type (remove closing parenthesis)
271+
receiverType := strings.TrimSuffix(receiverPart, ")")
272+
273+
// Extract just the function name (remove parameters)
274+
if idx := strings.Index(funcName, "("); idx > 0 {
275+
funcName = funcName[:idx]
276+
}
277+
278+
// Create unique key: ReceiverType.MethodName
279+
return receiverType + "." + funcName
280+
}
281+
} else {
282+
// Regular function: func Name(...)
283+
funcName := fields[1]
284+
if idx := strings.Index(funcName, "("); idx > 0 {
285+
return funcName[:idx]
286+
}
287+
}
288+
289+
return ""
290+
}
291+
292+
// skipGoDeclaration skips over a complete Go declaration (type, func, etc.)
293+
func skipGoDeclaration(lines []string, start int) int {
294+
if start >= len(lines) {
295+
return start
296+
}
297+
298+
line := strings.TrimSpace(lines[start])
299+
300+
// If it's a single-line declaration, just skip it
301+
if !strings.Contains(line, "{") {
302+
return start + 1
303+
}
304+
305+
// For multi-line declarations, count braces to find the end
306+
braceCount := 0
307+
i := start
308+
309+
for i < len(lines) {
310+
currentLine := lines[i]
311+
312+
// Count opening and closing braces
313+
for _, char := range currentLine {
314+
switch char {
315+
case '{':
316+
braceCount++
317+
case '}':
318+
braceCount--
319+
}
320+
}
321+
322+
i++
323+
324+
// If we've closed all braces, we're done
325+
if braceCount == 0 {
326+
break
327+
}
328+
}
329+
330+
// Skip any empty lines after the declaration
331+
for i < len(lines) && strings.TrimSpace(lines[i]) == "" {
332+
i++
333+
}
334+
335+
return i
336+
}

0 commit comments

Comments
 (0)