Skip to content

Commit d586c11

Browse files
Connecting up the augmentor
1 parent df8064f commit d586c11

File tree

2 files changed

+38
-2
lines changed

2 files changed

+38
-2
lines changed

compiler/astutil/astutil.go

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -531,8 +531,16 @@ func ConcatenateFiles(file *ast.File, tails ...*ast.File) error {
531531
path := imp.Path.Value
532532
if oldImp, ok := imports[path]; ok {
533533
// Import is in both files so check if the import name is not different.
534-
if oldName, newName := ImportName(oldImp), ImportName(imp); oldName != newName {
535-
return fmt.Errorf("import from of %s can not be concatenated with different name: %q != %q", path, oldName, newName)
534+
oldName, newName := ImportName(oldImp), ImportName(imp)
535+
if oldName != newName {
536+
if len(oldName) == 0 {
537+
// Update the import name to the new name.
538+
// This assumes the import name was `_` and
539+
// could cause problems if it was `.`
540+
oldImp.Name = imp.Name
541+
} else if len(newName) != 0 {
542+
return fmt.Errorf("import from of %s can not be concatenated with different name: %q != %q", path, oldName, newName)
543+
}
536544
}
537545
continue
538546
}

compiler/astutil/astutil_test.go

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1076,6 +1076,34 @@ func TestConcatenateFiles(t *testing.T) {
10761076
import f2 "fmt"
10771077
func bar() { f2.Println("bar") }`,
10781078
expErr: `import from of "fmt" can not be concatenated with different name: "f1" != "f2"`,
1079+
}, {
1080+
name: `import mismatch with old being blank`,
1081+
srcHead: `package testpackage
1082+
import _ "unsafe"
1083+
//go:linkname foo runtime.foo
1084+
func bar()`,
1085+
srcTail: `package testpackage
1086+
import "unsafe"
1087+
func foo() unsafe.Pointer { return nil }`,
1088+
want: `package testpackage
1089+
import "unsafe"
1090+
//go:linkname foo runtime.foo
1091+
func bar()
1092+
func foo() unsafe.Pointer { return nil }`,
1093+
}, {
1094+
name: `import mismatch with new being blank`,
1095+
srcHead: `package testpackage
1096+
import "unsafe"
1097+
func foo() unsafe.Pointer { return nil }`,
1098+
srcTail: `package testpackage
1099+
import _ "unsafe"
1100+
//go:linkname foo runtime.foo
1101+
func bar()`,
1102+
want: `package testpackage
1103+
import "unsafe"
1104+
func foo() unsafe.Pointer { return nil }
1105+
//go:linkname foo runtime.foo
1106+
func bar()`,
10791107
},
10801108
}
10811109

0 commit comments

Comments
 (0)