@@ -3,6 +3,8 @@ package compiler
33import (
44 "errors"
55 "fmt"
6+ "io/fs"
7+ "os"
68
79 "google.golang.org/protobuf/proto"
810
@@ -56,6 +58,10 @@ type config struct {
5658 objectTypePrefix * string
5759 allowedFlags * mapz.Set [string ]
5860 caveatTypeSet * caveattypes.TypeSet
61+
62+ // In an import context, this is the FS containing
63+ // the importing schema (as opposed to imported schemas)
64+ sourceFS fs.FS
5965}
6066
6167func SkipValidation () Option { return func (cfg * config ) { cfg .skipValidation = true } }
@@ -76,29 +82,50 @@ func CaveatTypeSet(cts *caveattypes.TypeSet) Option {
7682 return func (cfg * config ) { cfg .caveatTypeSet = cts }
7783}
7884
85+ // Config that supplies the root source folder for compilation. Required
86+ // for relative import syntax to work properly.
87+ func SourceFolder (sourceFolder string ) Option {
88+ return func (cfg * config ) { cfg .sourceFS = os .DirFS (sourceFolder ) }
89+ }
90+
91+ // Config that supplies the fs.FS for compilation as an alternative to
92+ // SourceFolder.
93+ func SourceFS (fsys fs.FS ) Option {
94+ return func (cfg * config ) { cfg .sourceFS = fsys }
95+ }
96+
7997const (
8098 expirationFlag = "expiration"
8199 selfFlag = "self"
82100 typeCheckingFlag = "typechecking"
83101 partialFlag = "partial"
102+ importFlag = "import"
84103)
85104
86- var allowedFlags = mapz .NewSet (expirationFlag , selfFlag , typeCheckingFlag , partialFlag )
105+ func allowedFlags () * mapz.Set [string ] {
106+ return mapz .NewSet (expirationFlag , selfFlag , typeCheckingFlag , partialFlag , importFlag )
107+ }
87108
88109func DisallowExpirationFlag () Option {
89110 return func (cfg * config ) {
90111 cfg .allowedFlags .Delete (expirationFlag )
91112 }
92113}
93114
115+ func DisallowImportFlag () Option {
116+ return func (cfg * config ) {
117+ cfg .allowedFlags .Delete (importFlag )
118+ }
119+ }
120+
94121type Option func (* config )
95122
96123type ObjectPrefixOption func (* config )
97124
98125// Compile compilers the input schema into a set of namespace definition protos.
99126func Compile (schema InputSchema , prefix ObjectPrefixOption , opts ... Option ) (* CompiledSchema , error ) {
100127 cfg := & config {
101- allowedFlags : allowedFlags ,
128+ allowedFlags : allowedFlags () ,
102129 }
103130
104131 prefix (cfg ) // required option
@@ -107,14 +134,36 @@ func Compile(schema InputSchema, prefix ObjectPrefixOption, opts ...Option) (*Co
107134 fn (cfg )
108135 }
109136
110- mapper := newPositionMapper (schema )
111- root := parser .Parse (createAstNode , schema .Source , schema .SchemaString ).(* dslNode )
112- errs := root .FindAll (dslshape .NodeTypeError )
113- if len (errs ) > 0 {
114- err := errorNodeToError (errs [0 ], mapper )
137+ root , mapper , err := parseSchema (schema )
138+ if err != nil {
115139 return nil , err
116140 }
117141
142+ present , err := validateImportPresence (cfg .allowedFlags .Has (importFlag ), root )
143+ if err != nil {
144+ // This condition should basically always be satisfied (we trigger errors off of the node),
145+ // but we're defensive here in case the implementation changes.
146+ var withNodeError withNodeError
147+ if errors .As (err , & withNodeError ) {
148+ return nil , toContextError (withNodeError .Error (), withNodeError .errorSourceCode , withNodeError .node , mapper )
149+ }
150+ return nil , err
151+ }
152+
153+ if present {
154+ // NOTE: import translation is done separately so that partial references
155+ // and definitions defined in separate files can correctly resolve.
156+ err = translateImports (importResolutionContext {
157+ globallyVisitedFiles : mapz .NewSet [string ](),
158+ locallyVisitedFiles : mapz .NewSet [string ](),
159+ sourceFS : cfg .sourceFS ,
160+ mapper : mapper ,
161+ }, root )
162+ if err != nil {
163+ return nil , err
164+ }
165+ }
166+
118167 initialCompiledPartials := make (map [string ][]* core.Relation )
119168 caveatTypeSet := caveattypes .TypeSetOrDefault (cfg .caveatTypeSet )
120169 compiled , err := translate (& translationContext {
@@ -141,6 +190,34 @@ func Compile(schema InputSchema, prefix ObjectPrefixOption, opts ...Option) (*Co
141190 return compiled , nil
142191}
143192
193+ func parseSchema (schema InputSchema ) (* dslNode , input.PositionMapper , error ) {
194+ mapper := newPositionMapper (schema )
195+ root := parser .Parse (createAstNode , schema .Source , schema .SchemaString ).(* dslNode )
196+ errs := root .FindAll (dslshape .NodeTypeError )
197+ if len (errs ) > 0 {
198+ err := errorNodeToError (errs [0 ], mapper )
199+ return nil , nil , err
200+ }
201+ return root , mapper , nil
202+ }
203+
204+ // validateImportPresence validates whether a given AST is valid based on whether
205+ // imports are allowed in the context. if they're present and disallowed it returns
206+ // a validation error; otherwise it returns the presence.
207+ func validateImportPresence (allowed bool , root * dslNode ) (present bool , err error ) {
208+ present = false
209+ for _ , topLevelNode := range root .GetChildren () {
210+ // Process import nodes; ignore the others
211+ if topLevelNode .GetType () == dslshape .NodeTypeImport {
212+ if ! allowed {
213+ return false , topLevelNode .Errorf ("import statements are not allowed in this context" )
214+ }
215+ present = true
216+ }
217+ }
218+ return present , nil
219+ }
220+
144221func errorNodeToError (node * dslNode , mapper input.PositionMapper ) error {
145222 if node .GetType () != dslshape .NodeTypeError {
146223 return errors .New ("given none error node" )
0 commit comments