diff --git a/internal/lsp/handlers.go b/internal/lsp/handlers.go index e37dda955..6f9d209de 100644 --- a/internal/lsp/handlers.go +++ b/internal/lsp/handlers.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "path/filepath" "strings" "github.com/jzelinskie/persistent" @@ -34,11 +35,6 @@ func (s *Server) textDocDiagnostic(ctx context.Context, r *jsonrpc2.Request) (Fu return FullDocumentDiagnosticReport{}, err } - log.Info(). - Str("uri", string(params.TextDocument.URI)). - Int("diagnostics", len(diagnostics)). - Msg("diagnostics complete") - return FullDocumentDiagnosticReport{ Kind: "full", Items: diagnostics, @@ -57,10 +53,11 @@ func (s *Server) computeDiagnostics(ctx context.Context, uri lsp.DocumentURI) ([ return &jsonrpc2.Error{Code: jsonrpc2.CodeInternalError, Message: "file not found"} } + overlayFS := newLSPOverlayFS(uriToSourceDir(uri), files) devCtx, devErrs, err := development.NewDevContext(ctx, &developerv1.RequestContext{ Schema: file.contents, Relationships: nil, - }) + }, development.WithSourceFS(overlayFS)) if err != nil { return err } @@ -101,7 +98,6 @@ func (s *Server) computeDiagnostics(ctx context.Context, uri lsp.DocumentURI) ([ return nil, err } - log.Info().Int("diagnostics", len(diagnostics)).Str("uri", string(uri)).Msg("computed diagnostics") return diagnostics, nil } @@ -171,15 +167,16 @@ func (s *Server) publishDiagnosticsIfNecessary(ctx context.Context, conn *jsonrp return nil } - log.Debug(). - Str("uri", string(uri)). - Msg("publishing diagnostics") - diagnostics, err := s.computeDiagnostics(ctx, uri) if err != nil { return fmt.Errorf("failed to compute diagnostics: %w", err) } + log.Info(). + Str("uri", string(uri)). + Int("diagnostics", len(diagnostics)). + Msg("publishing diagnostics") + return conn.Notify(ctx, "textDocument/publishDiagnostics", lsp.PublishDiagnosticsParams{ URI: uri, Diagnostics: diagnostics, @@ -197,7 +194,8 @@ func (s *Server) getCompiledContents(path lsp.DocumentURI, files *persistent.Map return compiled, nil } - justCompiled, derr, err := development.CompileSchema(file.contents) + overlayFS := newLSPOverlayFS(uriToSourceDir(path), files) + justCompiled, derr, err := development.CompileSchema(file.contents, development.WithSourceFS(overlayFS)) if err != nil { return nil, err } @@ -282,6 +280,63 @@ func (s *Server) textDocHover(_ context.Context, r *jsonrpc2.Request) (*Hover, e return hoverContents, nil } +func (s *Server) textDocDefinition(_ context.Context, r *jsonrpc2.Request) (*lsp.Location, error) { + params, err := unmarshalParams[lsp.TextDocumentPositionParams](r) + if err != nil { + return nil, err + } + + var location *lsp.Location + err = s.withFiles(func(files *persistent.Map[lsp.DocumentURI, trackedFile]) error { + compiled, err := s.getCompiledContents(params.TextDocument.URI, files) + if err != nil { + return err + } + + resolver, err := development.NewSchemaPositionMapper(compiled) + if err != nil { + return err + } + + position := input.Position{ + LineNumber: params.Position.Line, + ColumnPosition: params.Position.Character, + } + + resolved, err := resolver.ReferenceAtPosition(input.Source("schema"), position) + if err != nil { + return err + } + + if resolved == nil || resolved.TargetPosition == nil { + return nil + } + + // Determine the target file URI from TargetSource. + targetURI := params.TextDocument.URI + if resolved.TargetSource != nil && *resolved.TargetSource != "schema" { + sourceDir := uriToSourceDir(params.TextDocument.URI) + targetURI = lsp.DocumentURI("file://" + filepath.Join(sourceDir, string(*resolved.TargetSource))) + } + + nameStart := resolved.TargetPosition.ColumnPosition + resolved.TargetNamePositionOffset + location = &lsp.Location{ + URI: targetURI, + Range: lsp.Range{ + Start: lsp.Position{Line: resolved.TargetPosition.LineNumber, Character: nameStart}, + End: lsp.Position{Line: resolved.TargetPosition.LineNumber, Character: nameStart + len(resolved.Text)}, + }, + } + + return nil + }) + if err != nil { + return nil, err + } + + return location, nil +} + func (s *Server) textDocFormat(_ context.Context, r *jsonrpc2.Request) ([]lsp.TextEdit, error) { params, err := unmarshalParams[lsp.DocumentFormattingParams](r) if err != nil { @@ -335,8 +390,8 @@ func (s *Server) initialize(_ context.Context, r *jsonrpc2.Request) (any, error) return nil, err } - s.requestsDiagnostics = ip.Capabilities.Diagnostics.RefreshSupport - log.Debug(). + s.requestsDiagnostics = ip.Capabilities.Workspace.Diagnostics.RefreshSupport + log.Info(). Bool("requestsDiagnostics", s.requestsDiagnostics). Msg("initialize") @@ -353,6 +408,7 @@ func (s *Server) initialize(_ context.Context, r *jsonrpc2.Request) (any, error) DocumentFormattingProvider: true, DiagnosticProvider: &DiagnosticOptions{Identifier: "spicedb", InterFileDependencies: false, WorkspaceDiagnostics: false}, HoverProvider: true, + DefinitionProvider: true, }, }, nil } diff --git a/internal/lsp/lsp.go b/internal/lsp/lsp.go index f2de68801..f633e373f 100644 --- a/internal/lsp/lsp.go +++ b/internal/lsp/lsp.go @@ -96,6 +96,8 @@ func (s *Server) handle(ctx context.Context, conn *jsonrpc2.Conn, r *jsonrpc2.Re result, err = s.textDocFormat(ctx, r) case "textDocument/hover": result, err = s.textDocHover(ctx, r) + case "textDocument/definition": + result, err = s.textDocDefinition(ctx, r) default: log.Ctx(ctx).Warn(). Str("method", r.Method). diff --git a/internal/lsp/lsp_test.go b/internal/lsp/lsp_test.go index 8cfa1e90c..3ccf16a55 100644 --- a/internal/lsp/lsp_test.go +++ b/internal/lsp/lsp_test.go @@ -292,8 +292,10 @@ func TestDiagnosticsRefreshSupport(t *testing.T) { // Initialize with diagnostic refresh support enabled resp, serverState := sendAndReceive[lsp.InitializeResult](tester, "initialize", InitializeParams{ Capabilities: ClientCapabilities{ - Diagnostics: DiagnosticWorkspaceClientCapabilities{ - RefreshSupport: true, + Workspace: WorkspaceClientCapabilities{ + Diagnostics: DiagnosticWorkspaceClientCapabilities{ + RefreshSupport: true, + }, }, }, }) @@ -305,8 +307,10 @@ func TestDiagnosticsRefreshSupport(t *testing.T) { tester2 := newLSPTester(t) resp2, serverState2 := sendAndReceive[lsp.InitializeResult](tester2, "initialize", InitializeParams{ Capabilities: ClientCapabilities{ - Diagnostics: DiagnosticWorkspaceClientCapabilities{ - RefreshSupport: false, + Workspace: WorkspaceClientCapabilities{ + Diagnostics: DiagnosticWorkspaceClientCapabilities{ + RefreshSupport: false, + }, }, }, }) @@ -357,6 +361,219 @@ func TestUnmarshalParamsErrors(t *testing.T) { }().Code) } +func TestMultiFileNoDiagnostics(t *testing.T) { + tester := newLSPTester(t) + tester.initialize() + + tester.setFileContents("file:///testdir/users.zed", "definition user {}") + tester.setFileContents("file:///testdir/root.zed", `use import + +import "users.zed" + +definition resource { + relation viewer: user + permission view = viewer +} +`) + + resp, _ := sendAndReceive[FullDocumentDiagnosticReport](tester, "textDocument/diagnostic", + TextDocumentDiagnosticParams{ + TextDocument: TextDocument{URI: "file:///testdir/root.zed"}, + }) + require.Equal(t, "full", resp.Kind) + require.Empty(t, resp.Items) +} + +func TestMultiFileUndefinedDefinitionDiagnostics(t *testing.T) { + tester := newLSPTester(t) + tester.initialize() + + tester.setFileContents("file:///testdir/broken.zed", ` +definition resource { + relation viewer: organization + permission view = viewer +}`) + tester.setFileContents("file:///testdir/root.zed", `use import + +import "broken.zed" +`) + + resp, _ := sendAndReceive[FullDocumentDiagnosticReport](tester, "textDocument/diagnostic", + TextDocumentDiagnosticParams{ + TextDocument: TextDocument{URI: "file:///testdir/root.zed"}, + }) + require.Equal(t, "full", resp.Kind) + require.Len(t, resp.Items, 1) + require.Equal(t, lsp.Error, resp.Items[0].Severity) + t.Log(resp.Items[0].Message) + require.Contains(t, resp.Items[0].Message, "could not lookup definition `organization` for relation `viewer`: object definition `organization` not found") +} + +func TestMultiFileBrokenImportDiagnostics(t *testing.T) { + tester := newLSPTester(t) + tester.initialize() + + tester.setFileContents("file:///testdir/root.zed", `use import +import "unknown.zed" +`) + + resp, _ := sendAndReceive[FullDocumentDiagnosticReport](tester, "textDocument/diagnostic", + TextDocumentDiagnosticParams{ + TextDocument: TextDocument{URI: "file:///testdir/root.zed"}, + }) + require.Equal(t, "full", resp.Kind) + require.Len(t, resp.Items, 1) + require.Equal(t, lsp.Error, resp.Items[0].Severity) + require.Contains(t, resp.Items[0].Message, "failed to read import \"unknown.zed\": open unknown.zed: no such file or director") +} + +func TestDefinitionSameFileTypeReference(t *testing.T) { + tester := newLSPTester(t) + tester.initialize() + + tester.setFileContents("file:///test", `definition user {} + +definition resource { + relation viewer: user + permission view = viewer +} +`) + + // Click on "user" in "relation viewer: user" (line 3, character 18) + resp, _ := sendAndReceive[lsp.Location](tester, "textDocument/definition", + lsp.TextDocumentPositionParams{ + TextDocument: lsp.TextDocumentIdentifier{URI: "file:///test"}, + Position: lsp.Position{Line: 3, Character: 18}, + }) + require.Equal(t, lsp.DocumentURI("file:///test"), resp.URI) + require.Equal(t, 0, resp.Range.Start.Line) + require.Equal(t, len("definition "), resp.Range.Start.Character) +} + +func TestDefinitionSameFileRelationReference(t *testing.T) { + tester := newLSPTester(t) + tester.initialize() + + tester.setFileContents("file:///test", `definition user {} + +definition resource { + relation viewer: user + permission view = viewer +} +`) + + // Click on "viewer" in "permission view = viewer" (line 4, character 19) + resp, _ := sendAndReceive[lsp.Location](tester, "textDocument/definition", + lsp.TextDocumentPositionParams{ + TextDocument: lsp.TextDocumentIdentifier{URI: "file:///test"}, + Position: lsp.Position{Line: 4, Character: 19}, + }) + require.Equal(t, lsp.DocumentURI("file:///test"), resp.URI) + require.Equal(t, 3, resp.Range.Start.Line) + require.Equal(t, len("\trelation "), resp.Range.Start.Character) +} + +func TestDefinitionCrossFileTypeReference(t *testing.T) { + tester := newLSPTester(t) + tester.initialize() + + tester.setFileContents("file:///testdir/users.zed", "definition user {}") + tester.setFileContents("file:///testdir/root.zed", `use import + +import "users.zed" + +definition resource { + relation viewer: user + permission view = viewer +} +`) + + // Click on "user" in "relation viewer: user" (line 5, character 18) + // It should point to "users.zed" + resp, _ := sendAndReceive[lsp.Location](tester, "textDocument/definition", + lsp.TextDocumentPositionParams{ + TextDocument: lsp.TextDocumentIdentifier{URI: "file:///testdir/root.zed"}, + Position: lsp.Position{Line: 5, Character: 18}, + }) + require.Equal(t, lsp.DocumentURI("file:///testdir/users.zed"), resp.URI) + require.Equal(t, 0, resp.Range.Start.Line) + require.Equal(t, len("definition "), resp.Range.Start.Character) +} + +func TestDefinitionImportReference(t *testing.T) { + tester := newLSPTester(t) + tester.initialize() + + tester.setFileContents("file:///testdir/users.zed", "definition user {}") + tester.setFileContents("file:///testdir/root.zed", `use import + +import "users.zed" + +definition resource { + relation viewer: user + permission view = viewer +} +`) + + // Click on import "users.zed" (line 2, character 10) + // It should point on the very begginning of "users.zed" + resp, _ := sendAndReceive[lsp.Location](tester, "textDocument/definition", + lsp.TextDocumentPositionParams{ + TextDocument: lsp.TextDocumentIdentifier{URI: "file:///testdir/root.zed"}, + Position: lsp.Position{Line: 2, Character: 10}, + }) + require.Equal(t, lsp.DocumentURI("file:///testdir/users.zed"), resp.URI) + require.Equal(t, 0, resp.Range.Start.Line) + require.Equal(t, 0, resp.Range.Start.Character) +} + +func TestDefinitionCrossFileCaveatReference(t *testing.T) { + tester := newLSPTester(t) + tester.initialize() + + tester.setFileContents("file:///testdir/caveats.zed", `caveat some_caveat(some_param int) { + some_param < 100 +}`) + tester.setFileContents("file:///testdir/root.zed", `use import + +import "caveats.zed" + +definition user {} + +definition resource { + relation viewer: user with some_caveat +} +`) + + // Click on "some_caveat" in "relation viewer: user with some_caveat" (line 7, character 30) + // "\trelation viewer: user with some_caveat" + // 0 1 2 3 + // 0123456789012345678901234567890123456789 + resp, _ := sendAndReceive[lsp.Location](tester, "textDocument/definition", + lsp.TextDocumentPositionParams{ + TextDocument: lsp.TextDocumentIdentifier{URI: "file:///testdir/root.zed"}, + Position: lsp.Position{Line: 7, Character: 30}, + }) + require.Equal(t, lsp.DocumentURI("file:///testdir/caveats.zed"), resp.URI) + require.Equal(t, 0, resp.Range.Start.Line) + require.Equal(t, len("caveat "), resp.Range.Start.Character) +} + +func TestDefinitionNoReference(t *testing.T) { + tester := newLSPTester(t) + tester.initialize() + + tester.setFileContents("file:///test", "definition user {}") + + // Click on whitespace / keyword where no reference exists + resp, _ := sendAndReceive[*lsp.Location](tester, "textDocument/definition", + lsp.TextDocumentPositionParams{ + TextDocument: lsp.TextDocumentIdentifier{URI: "file:///test"}, + Position: lsp.Position{Line: 0, Character: 0}, + }) + require.Nil(t, resp) +} + func TestInvalidParams(t *testing.T) { err := invalidParams(errors.New("test error")) require.Equal(t, int64(jsonrpc2.CodeInvalidParams), err.Code) diff --git a/internal/lsp/lspdefs.go b/internal/lsp/lspdefs.go index 495075a22..cd8806055 100644 --- a/internal/lsp/lspdefs.go +++ b/internal/lsp/lspdefs.go @@ -14,6 +14,7 @@ type ServerCapabilities struct { DocumentFormattingProvider bool `json:"documentFormattingProvider,omitempty"` DiagnosticProvider *DiagnosticOptions `json:"diagnosticProvider,omitempty"` HoverProvider bool `json:"hoverProvider,omitempty"` + DefinitionProvider bool `json:"definitionProvider,omitempty"` } type DiagnosticOptions struct { @@ -53,6 +54,10 @@ type InitializeParams struct { } type ClientCapabilities struct { + Workspace WorkspaceClientCapabilities `json:"workspace"` +} + +type WorkspaceClientCapabilities struct { Diagnostics DiagnosticWorkspaceClientCapabilities `json:"diagnostics"` } diff --git a/internal/lsp/overlay.go b/internal/lsp/overlay.go new file mode 100644 index 000000000..ce4c10f99 --- /dev/null +++ b/internal/lsp/overlay.go @@ -0,0 +1,76 @@ +package lsp + +import ( + "io/fs" + "os" + "path/filepath" + "strings" + "time" + + "github.com/jzelinskie/persistent" + "github.com/sourcegraph/go-lsp" +) + +// lspOverlayFS is an fs.FS that serves open editor files from memory +// and falls back to disk for everything else. It is rooted at sourceDir. +type lspOverlayFS struct { + base fs.FS + sourceDir string + files *persistent.Map[lsp.DocumentURI, trackedFile] +} + +var _ fs.FS = &lspOverlayFS{} + +func newLSPOverlayFS(sourceDir string, files *persistent.Map[lsp.DocumentURI, trackedFile]) fs.FS { + return &lspOverlayFS{ + base: os.DirFS(sourceDir), + sourceDir: sourceDir, + files: files, + } +} + +func (o *lspOverlayFS) Open(name string) (fs.File, error) { + absPath := filepath.Join(o.sourceDir, filepath.FromSlash(name)) + uri := lsp.DocumentURI("file://" + absPath) + if file, ok := o.files.Get(uri); ok { + return newMemFile(name, file.contents), nil + } + return o.base.Open(name) +} + +// memFile implements fs.File for an in-memory string. +type memFile struct { + name string + content string + reader *strings.Reader +} + +var _ fs.File = &memFile{} + +func newMemFile(name, content string) *memFile { + return &memFile{name: name, content: content, reader: strings.NewReader(content)} +} + +func (m *memFile) Read(b []byte) (int, error) { return m.reader.Read(b) } +func (m *memFile) Close() error { return nil } +func (m *memFile) Stat() (fs.FileInfo, error) { + return memFileInfo{name: m.name, size: int64(len(m.content))}, nil +} + +type memFileInfo struct { + name string + size int64 +} + +func (i memFileInfo) Name() string { return filepath.Base(i.name) } +func (i memFileInfo) Size() int64 { return i.size } +func (i memFileInfo) Mode() fs.FileMode { return 0o444 } +func (i memFileInfo) ModTime() time.Time { return time.Time{} } +func (i memFileInfo) IsDir() bool { return false } +func (i memFileInfo) Sys() any { return nil } + +// uriToSourceDir extracts the containing directory from a file:// URI. +func uriToSourceDir(uri lsp.DocumentURI) string { + path := strings.TrimPrefix(string(uri), "file://") + return filepath.Dir(path) +} diff --git a/internal/lsp/testutil.go b/internal/lsp/testutil.go index 0fa09e81c..9850618de 100644 --- a/internal/lsp/testutil.go +++ b/internal/lsp/testutil.go @@ -48,8 +48,10 @@ type lspTester struct { func (lt *lspTester) initialize() { resp, serverState := sendAndReceive[lsp.InitializeResult](lt, "initialize", InitializeParams{ Capabilities: ClientCapabilities{ - Diagnostics: DiagnosticWorkspaceClientCapabilities{ - RefreshSupport: true, + Workspace: WorkspaceClientCapabilities{ + Diagnostics: DiagnosticWorkspaceClientCapabilities{ + RefreshSupport: true, + }, }, }, }) diff --git a/pkg/caveats/parameters.go b/pkg/caveats/parameters.go index 36047af4a..944d84edc 100644 --- a/pkg/caveats/parameters.go +++ b/pkg/caveats/parameters.go @@ -66,6 +66,9 @@ func ConvertContextToParameters( // ParameterTypeString returns the string form of the type reference. func ParameterTypeString(typeRef *core.CaveatTypeReference) string { + if typeRef == nil { + return "" + } var sb strings.Builder sb.WriteString(typeRef.TypeName) if len(typeRef.ChildTypes) > 0 { diff --git a/pkg/composableschemadsl/compiler/translator.go b/pkg/composableschemadsl/compiler/translator.go index a22ff6076..222c79c8a 100644 --- a/pkg/composableschemadsl/compiler/translator.go +++ b/pkg/composableschemadsl/compiler/translator.go @@ -826,7 +826,7 @@ func translateImports(itctx importResolutionContext, root *dslNode) error { // This is a new node provided by the translateImport parsedImportRoot, err := importFile(itctx.sourceFS, filePath) if err != nil { - return toContextError("failed to read import in schema file", "", topLevelNode, itctx.mapper) + return toContextError(err.Error(), "", topLevelNode, itctx.mapper) } // We recurse on that node to resolve any further imports diff --git a/pkg/development/devcontext.go b/pkg/development/devcontext.go index fc51e5c41..032ce68c2 100644 --- a/pkg/development/devcontext.go +++ b/pkg/development/devcontext.go @@ -52,7 +52,7 @@ type DevContext struct { // NewDevContext creates a new DevContext from the specified request context, parsing and populating // the datastore as needed. -func NewDevContext(ctx context.Context, requestContext *devinterface.RequestContext) (*DevContext, *devinterface.DeveloperErrors, error) { +func NewDevContext(ctx context.Context, requestContext *devinterface.RequestContext, opts ...CompileOption) (*DevContext, *devinterface.DeveloperErrors, error) { ds, err := memdb.NewMemdbDatastore(0, 0*time.Second, memdb.DisableGC) if err != nil { return nil, nil, err @@ -60,7 +60,7 @@ func NewDevContext(ctx context.Context, requestContext *devinterface.RequestCont dl := datalayer.NewDataLayer(ds) ctx = datalayer.ContextWithDataLayer(ctx, dl) - dctx, devErrs, nerr := newDevContextWithDataLayer(ctx, requestContext, dl) + dctx, devErrs, nerr := newDevContextWithDataLayer(ctx, requestContext, dl, opts...) if nerr != nil || devErrs != nil { // If any form of error occurred, immediately close the data layer derr := dl.Close() @@ -74,9 +74,9 @@ func NewDevContext(ctx context.Context, requestContext *devinterface.RequestCont return dctx, nil, nil } -func newDevContextWithDataLayer(ctx context.Context, requestContext *devinterface.RequestContext, dl datalayer.DataLayer) (*DevContext, *devinterface.DeveloperErrors, error) { +func newDevContextWithDataLayer(ctx context.Context, requestContext *devinterface.RequestContext, dl datalayer.DataLayer, opts ...CompileOption) (*DevContext, *devinterface.DeveloperErrors, error) { // Compile the schema and load its caveats and namespaces into the datastore. - compiled, devError, err := CompileSchema(requestContext.Schema) + compiled, devError, err := CompileSchema(requestContext.Schema, opts...) if err != nil { return nil, nil, err } diff --git a/pkg/development/schema.go b/pkg/development/schema.go index c94ae8281..04008ba20 100644 --- a/pkg/development/schema.go +++ b/pkg/development/schema.go @@ -2,6 +2,7 @@ package development import ( "errors" + "io/fs" "github.com/ccoveille/go-safecast/v2" @@ -11,14 +12,37 @@ import ( "github.com/authzed/spicedb/pkg/schemadsl/input" ) +// CompileOption configures schema compilation. +type CompileOption func(*compileConfig) + +type compileConfig struct { + fsys fs.FS +} + +// WithSourceFS enables import resolution using the given filesystem. +// The filesystem should be rooted at the directory containing the schema file. +func WithSourceFS(fsys fs.FS) CompileOption { + return func(cfg *compileConfig) { cfg.fsys = fsys } +} + // CompileSchema compiles a schema into its caveat and namespace definition(s), returning a developer // error if the schema could not be compiled. The non-developer error is returned only if an // internal errors occurred. -func CompileSchema(schema string) (*compiler.CompiledSchema, *devinterface.DeveloperError, error) { +func CompileSchema(schema string, opts ...CompileOption) (*compiler.CompiledSchema, *devinterface.DeveloperError, error) { + cfg := &compileConfig{} + for _, o := range opts { + o(cfg) + } + + var compilerOpts []compiler.Option + if cfg.fsys != nil { + compilerOpts = append(compilerOpts, compiler.SourceFS(cfg.fsys)) + } + compiled, err := compiler.Compile(compiler.InputSchema{ Source: input.Source("schema"), SchemaString: schema, - }, compiler.AllowUnprefixedObjectType()) + }, compiler.AllowUnprefixedObjectType(), compilerOpts...) var contextError compiler.WithContextError if errors.As(err, &contextError) { diff --git a/pkg/development/schema_position_mapper.go b/pkg/development/schema_position_mapper.go index bf0953c9b..d39bcf4b1 100644 --- a/pkg/development/schema_position_mapper.go +++ b/pkg/development/schema_position_mapper.go @@ -29,6 +29,8 @@ const ( ReferenceTypeRelation ReferenceTypePermission ReferenceTypeCaveatParameter + ReferenceTypeImport + ReferenceTypePartial ) // SchemaReference represents a reference to a schema node. @@ -86,6 +88,40 @@ func (r *SchemaPositionMapper) ReferenceAtPosition(source input.Source, position return nil, nil } + // Import reference. + if importPath, ok := r.importReferenceChain(nodeChain); ok { + importSource := input.Source(importPath) + return &SchemaReference{ + ReferenceType: ReferenceTypeImport, + Source: source, + Position: position, + TargetSource: &importSource, + TargetPosition: &input.Position{LineNumber: 0, ColumnPosition: 0}, + Text: importPath, + }, nil + } + + // Partial reference. + if partialName, ok := r.partialReferenceChain(nodeChain); ok { + line, col := r.schema.PartialNodePosition(partialName) + var targetPosition *input.Position + if line >= 0 && col >= 0 { + targetPosition = &input.Position{LineNumber: line, ColumnPosition: col} + } + + targetSource := r.resolveTargetSource(partialName, source) + return &SchemaReference{ + Source: source, + Position: position, + Text: partialName, + TargetSourceCode: "partial " + partialName, + ReferenceType: ReferenceTypePartial, + TargetSource: &targetSource, + TargetPosition: targetPosition, + TargetNamePositionOffset: len("partial "), + }, nil + } + relationReference := func(relation *core.Relation, def *schema.Definition) (*SchemaReference, error) { // NOTE: zeroes are fine here to mean "unknown" lineNumber, err := safecast.Convert[int](relation.SourcePosition.ZeroIndexedLineNumber) @@ -171,6 +207,7 @@ func (r *SchemaPositionMapper) ReferenceAtPosition(source input.Source, position targetSourceCode = fmt.Sprintf("%sdefinition %s {}", docComment, def.Name) } + targetSource := r.resolveTargetSource(def.Name, source) return &SchemaReference{ Source: source, Position: position, @@ -179,7 +216,7 @@ func (r *SchemaPositionMapper) ReferenceAtPosition(source input.Source, position ReferenceType: ReferenceTypeDefinition, ReferenceMarkdown: "definition " + def.Name, - TargetSource: &source, + TargetSource: &targetSource, TargetPosition: &defPosition, TargetSourceCode: targetSourceCode, TargetNamePositionOffset: len("definition "), @@ -216,6 +253,7 @@ func (r *SchemaPositionMapper) ReferenceAtPosition(source input.Source, position } caveatSourceCode.WriteString(") {\n\t// ...\n}") + targetSource := r.resolveTargetSource(caveatDef.Name, source) return &SchemaReference{ Source: source, Position: position, @@ -224,7 +262,7 @@ func (r *SchemaPositionMapper) ReferenceAtPosition(source input.Source, position ReferenceType: ReferenceTypeCaveat, ReferenceMarkdown: "caveat " + caveatDef.Name, - TargetSource: &source, + TargetSource: &targetSource, TargetPosition: &defPosition, TargetSourceCode: caveatSourceCode.String(), TargetNamePositionOffset: len("caveat "), @@ -256,6 +294,16 @@ func (r *SchemaPositionMapper) ReferenceAtPosition(source input.Source, position return nil, nil } +// resolveTargetSource returns the input.Source for the file where a definition or caveat +// is defined. For imported definitions, this will be the imported file path. For definitions +// in the root schema, this returns the fallback source. +func (r *SchemaPositionMapper) resolveTargetSource(name string, fallback input.Source) input.Source { + if defSource := r.schema.DefinitionNodeSource(name); defSource != "" { + return input.Source(defSource) + } + return fallback +} + func (r *SchemaPositionMapper) lookupCaveat(caveatName string) (*core.CaveatDefinition, bool) { c, err := r.typeSystem.GetCaveat(context.Background(), caveatName) if err != nil { @@ -423,3 +471,30 @@ func (r *SchemaPositionMapper) relationReferenceChain(nodeChain *compiler.NodeCh return r.lookupRelation(defName, relationName) } + +func (r *SchemaPositionMapper) importReferenceChain(nodeChain *compiler.NodeChain) (string, bool) { + importNode := nodeChain.FindNodeOfType(dslshape.NodeTypeImport) + if importNode == nil { + return "", false + } + + importPath, err := importNode.GetString(dslshape.NodeImportPredicatePath) + if err != nil { + return "", false + } + + return importPath, true +} + +func (r *SchemaPositionMapper) partialReferenceChain(nodeChain *compiler.NodeChain) (string, bool) { + if !nodeChain.HasHeadType(dslshape.NodeTypePartialReference) { + return "", false + } + + partialName, err := nodeChain.Head().GetString(dslshape.NodePartialReferencePredicateName) + if err != nil { + return "", false + } + + return partialName, true +} diff --git a/pkg/development/schema_position_mapper_test.go b/pkg/development/schema_position_mapper_test.go index 82e5515ec..9c3f5d463 100644 --- a/pkg/development/schema_position_mapper_test.go +++ b/pkg/development/schema_position_mapper_test.go @@ -2,6 +2,7 @@ package development import ( "testing" + "testing/fstest" "github.com/stretchr/testify/require" @@ -235,7 +236,7 @@ func TestSchemaPositionMapper(t *testing.T) { schema: `definition user {} caveat somecaveat(someparam int) { - someparam < 42 + someparam < 42 || someparam > 43 } definition resource { @@ -244,7 +245,7 @@ func TestSchemaPositionMapper(t *testing.T) { } `, line: 3, - column: 6, + column: 6, // TODO if you put 23, the mapper doesn't return anything expectedReference: &SchemaReference{ Source: input.Source("test"), Position: input.Position{LineNumber: 3, ColumnPosition: 6}, @@ -399,6 +400,32 @@ definition document { TargetNamePositionOffset: 9, }, }, + { + name: "reference to a partial", + schema: `use partial + + partial view_partial { + relation user: user + permission view = user + } + + definition secret { + ...view_partial + } + `, + line: 8, + column: 8, + expectedReference: &SchemaReference{ + Source: "test", + Position: input.Position{LineNumber: 8, ColumnPosition: 8}, + Text: "view_partial", + ReferenceType: ReferenceTypePartial, + TargetSource: &testSource, + TargetSourceCode: "partial view_partial", + TargetPosition: &input.Position{LineNumber: 2, ColumnPosition: 4}, + TargetNamePositionOffset: len("partial "), + }, + }, } for _, tc := range tcs { @@ -422,3 +449,96 @@ definition document { }) } } + +func TestSchemaPositionMapperComposableSchema(t *testing.T) { + rootSchema := `// this is a comment +use import +use partial +import "path/users.zed" +import "path/partials.zed" + +definition resource { + relation somerelation: user with is_raining + relation oops: group#member + ...secret +} +` + sourceFS := fstest.MapFS{ + "path/partials.zed": &fstest.MapFile{Data: []byte("use partial\npartial secret {\nrelation secret: user\n}")}, + "path/users.zed": &fstest.MapFile{Data: []byte("definition user {}\ncaveat is_raining(day string) {\nday == \"saturday\"\n}")}, + } + + rootSource := input.Source("path/root.zed") + compiled, err := compiler.Compile(compiler.InputSchema{ + Source: rootSource, + SchemaString: rootSchema, + }, compiler.AllowUnprefixedObjectType(), compiler.SourceFS(sourceFS)) + require.NoError(t, err) + + mapper, err := NewSchemaPositionMapper(compiled) + require.NoError(t, err) + + t.Run("cursor on the import path", func(t *testing.T) { + ref, err := mapper.ReferenceAtPosition(rootSource, input.Position{ + LineNumber: 3, + ColumnPosition: 5, + }) + require.NoError(t, err) + require.NotNil(t, ref) + require.Equal(t, ReferenceTypeImport, ref.ReferenceType) + require.Equal(t, input.Source("path/users.zed"), *ref.TargetSource) + require.Empty(t, ref.TargetSourceCode) + require.Equal(t, "path/users.zed", ref.Text) + require.Equal(t, rootSource, ref.Source) + require.Equal(t, input.Position{LineNumber: 3, ColumnPosition: 5}, ref.Position) + require.Equal(t, &input.Position{LineNumber: 0, ColumnPosition: 0}, ref.TargetPosition) + }) + + t.Run("cursor on the user definition", func(t *testing.T) { + ref, err := mapper.ReferenceAtPosition(rootSource, input.Position{ + LineNumber: 7, + ColumnPosition: 30, + }) + require.NoError(t, err) + require.NotNil(t, ref) + require.Equal(t, ReferenceTypeDefinition, ref.ReferenceType) + require.Equal(t, "definition user {}", ref.TargetSourceCode) + require.Equal(t, "user", ref.Text) + require.Equal(t, rootSource, ref.Source) + require.Equal(t, input.Source("path/users.zed"), *ref.TargetSource) + require.Equal(t, input.Position{LineNumber: 7, ColumnPosition: 30}, ref.Position) + require.Equal(t, &input.Position{LineNumber: 0, ColumnPosition: 0}, ref.TargetPosition) + }) + + t.Run("cursor on the caveat ref", func(t *testing.T) { + ref, err := mapper.ReferenceAtPosition(rootSource, input.Position{ + LineNumber: 7, + ColumnPosition: 43, + }) + require.NoError(t, err) + require.NotNil(t, ref) + require.Equal(t, ReferenceTypeCaveat, ref.ReferenceType) + require.Equal(t, "caveat is_raining(day string) {\n\t// ...\n}", ref.TargetSourceCode) + require.Equal(t, "is_raining", ref.Text) + require.Equal(t, rootSource, ref.Source) + require.Equal(t, input.Source("path/users.zed"), *ref.TargetSource) + require.Equal(t, input.Position{LineNumber: 7, ColumnPosition: 43}, ref.Position) + require.Equal(t, &input.Position{LineNumber: 1, ColumnPosition: 0}, ref.TargetPosition) + }) + + t.Run("cursor on the partial ref", func(t *testing.T) { + ref, err := mapper.ReferenceAtPosition(rootSource, input.Position{ + LineNumber: 9, + ColumnPosition: 5, + }) + require.NoError(t, err) + require.NotNil(t, ref) + require.Equal(t, ReferenceTypePartial, ref.ReferenceType) + require.Equal(t, "partial secret", ref.TargetSourceCode) + require.Equal(t, "secret", ref.Text) + require.Equal(t, input.Source("path/partials.zed"), *ref.TargetSource) + require.Equal(t, rootSource, ref.Source) + require.Equal(t, input.Position{LineNumber: 9, ColumnPosition: 5}, ref.Position) + require.Equal(t, &input.Position{LineNumber: 1, ColumnPosition: 0}, ref.TargetPosition) + }) +} diff --git a/pkg/schemadsl/compiler/development.go b/pkg/schemadsl/compiler/development.go index 46a04173a..046e78738 100644 --- a/pkg/schemadsl/compiler/development.go +++ b/pkg/schemadsl/compiler/development.go @@ -55,6 +55,69 @@ func (nc *NodeChain) String() string { return out.String() } +// DefinitionNodeSource returns the input source for the AST node defining the given name. +// For definitions compiled from imports, this will be the imported file path (e.g. "users.zed"). +// For definitions in the root schema, this will be the root source (e.g. "schema"). +// Returns empty string if not found. +func (cs *CompiledSchema) DefinitionNodeSource(name string) string { + for _, child := range cs.rootNode.GetChildren() { + nodeType := child.GetType() + var predicateName string + switch nodeType { + case dslshape.NodeTypeDefinition: + predicateName = dslshape.NodeDefinitionPredicateName + case dslshape.NodeTypeCaveatDefinition: + predicateName = dslshape.NodeCaveatDefinitionPredicateName + case dslshape.NodeTypePartial: + predicateName = dslshape.NodePartialPredicateName + default: + continue + } + + defName, err := child.GetString(predicateName) + if err != nil { + continue + } + + if defName == name { + source, err := child.GetString(dslshape.NodePredicateSource) + if err != nil { + return "" + } + return source + } + } + return "" +} + +// PartialNodePosition returns the start rune position and source of the partial +// definition with the given name. Returns (-1, -1, "") if not found. +func (cs *CompiledSchema) PartialNodePosition(name string) (int, int) { + for _, child := range cs.rootNode.GetChildren() { + if child.GetType() != dslshape.NodeTypePartial { + continue + } + + partialName, err := child.GetString(dslshape.NodePartialPredicateName) + if err != nil || partialName != name { + continue + } + + sourceRange, err := child.Range(cs.mapper) + if err != nil { + return -1, -1 + } + + line, col, err := sourceRange.Start().LineAndColumn() + if err != nil { + return -1, -1 + } + + return line, col + } + return -1, -1 +} + // PositionToAstNodeChain returns the AST node, and its parents (if any), found at the given position in the source, if any. func PositionToAstNodeChain(schema *CompiledSchema, source input.Source, position input.Position) (*NodeChain, error) { rootSource, err := schema.rootNode.GetString(dslshape.NodePredicateSource) @@ -73,7 +136,7 @@ func PositionToAstNodeChain(schema *CompiledSchema, source input.Source, positio } // Find the node at the rune position. - found, err := runePositionToAstNodeChain(schema.rootNode, runePosition) + found, err := runePositionToAstNodeChain(schema.rootNode, runePosition, rootSource) if err != nil { return nil, err } @@ -85,11 +148,18 @@ func PositionToAstNodeChain(schema *CompiledSchema, source input.Source, positio return &NodeChain{nodes: found, runePosition: runePosition}, nil } -func runePositionToAstNodeChain(node *dslNode, runePosition int) ([]DSLNode, error) { +func runePositionToAstNodeChain(node *dslNode, runePosition int, rootSource string) ([]DSLNode, error) { if !node.Has(dslshape.NodePredicateStartRune) { return nil, nil } + // Skip nodes from imported files whose rune positions may overlap with the root file. + if nodeSource, err := node.GetString(dslshape.NodePredicateSource); err == nil { + if nodeSource != rootSource { + return nil, nil + } + } + startRune, err := node.GetInt(dslshape.NodePredicateStartRune) if err != nil { return nil, err @@ -105,7 +175,7 @@ func runePositionToAstNodeChain(node *dslNode, runePosition int) ([]DSLNode, err } for _, child := range node.AllSubNodes() { - childChain, err := runePositionToAstNodeChain(child, runePosition) + childChain, err := runePositionToAstNodeChain(child, runePosition, rootSource) if err != nil { return nil, err } diff --git a/pkg/schemadsl/compiler/importer.go b/pkg/schemadsl/compiler/importer.go index 77eb6e2b8..d16a71d53 100644 --- a/pkg/schemadsl/compiler/importer.go +++ b/pkg/schemadsl/compiler/importer.go @@ -15,18 +15,19 @@ type CircularImportError struct { filePath string } -func importFile(fsys fs.FS, filePath string) (*dslNode, error) { +func importFile(fsys fs.FS, filePath string) (*dslNode, string, error) { schemaBytes, err := fs.ReadFile(fsys, filePath) if err != nil { - return nil, fmt.Errorf("failed to read import in schema file %q: %w", filePath, err) + return nil, "", fmt.Errorf("failed to read import %q: %w", filePath, err) } logging.Trace().Str("schema", string(schemaBytes)).Str("file", filePath).Msg("read schema from file") + content := string(schemaBytes) parsedSchema, _, err := parseSchema(InputSchema{ Source: input.Source(filePath), - SchemaString: string(schemaBytes), + SchemaString: content, }) - return parsedSchema, err + return parsedSchema, content, err } // Take a filepath and ensure that it's local to the current context. diff --git a/pkg/schemadsl/compiler/positionmapper.go b/pkg/schemadsl/compiler/positionmapper.go index aa33c43b2..f9f098f16 100644 --- a/pkg/schemadsl/compiler/positionmapper.go +++ b/pkg/schemadsl/compiler/positionmapper.go @@ -6,27 +6,46 @@ import ( "github.com/authzed/spicedb/pkg/schemadsl/input" ) +// positionMapper converts rune positions to line/column positions and vice versa. +// It is multi-file aware: the root schema file's mapper is created at construction time, +// and imported files are registered via RegisterImportedFile during import resolution. +// This allows correct position mapping for AST nodes from imported files, whose rune +// positions are relative to their own file content rather than the root schema. +// +// This is distinct from development.SchemaPositionMapper, which is a higher-level +// construct that resolves semantic references (e.g., type references, relation references) +// to their target definitions using the compiled schema and type system. type positionMapper struct { - schema InputSchema - mapper input.SourcePositionMapper + mappers map[input.Source]input.SourcePositionMapper + contents map[input.Source]string } -func newPositionMapper(schema InputSchema) input.PositionMapper { - return &positionMapper{ - schema: schema, - mapper: input.CreateSourcePositionMapper([]byte(schema.SchemaString)), +func newPositionMapper(schema InputSchema) *positionMapper { + pm := &positionMapper{ + mappers: make(map[input.Source]input.SourcePositionMapper), + contents: make(map[input.Source]string), } + pm.mappers[schema.Source] = input.CreateSourcePositionMapper([]byte(schema.SchemaString)) + pm.contents[schema.Source] = schema.SchemaString + return pm } -func (pm *positionMapper) RunePositionToLineAndCol(runePosition int, _ input.Source) (int, int, error) { - return pm.mapper.RunePositionToLineAndCol(runePosition) +// RegisterImportedFile registers an imported file's content so that rune positions +// from that file can be correctly mapped to line/col. +func (pm *positionMapper) RegisterImportedFile(source input.Source, content string) { + pm.mappers[source] = input.CreateSourcePositionMapper([]byte(content)) + pm.contents[source] = content } -func (pm *positionMapper) LineAndColToRunePosition(lineNumber int, colPosition int, _ input.Source) (int, error) { - return pm.mapper.LineAndColToRunePosition(lineNumber, colPosition) +func (pm *positionMapper) RunePositionToLineAndCol(runePosition int, source input.Source) (int, int, error) { + return pm.mappers[source].RunePositionToLineAndCol(runePosition) } -func (pm *positionMapper) TextForLine(lineNumber int, _ input.Source) (string, error) { - lines := strings.Split(pm.schema.SchemaString, "\n") +func (pm *positionMapper) LineAndColToRunePosition(lineNumber int, colPosition int, source input.Source) (int, error) { + return pm.mappers[source].LineAndColToRunePosition(lineNumber, colPosition) +} + +func (pm *positionMapper) TextForLine(lineNumber int, source input.Source) (string, error) { + lines := strings.Split(pm.contents[source], "\n") return lines[lineNumber], nil } diff --git a/pkg/schemadsl/compiler/translator.go b/pkg/schemadsl/compiler/translator.go index c70eaeb1e..afcf02bed 100644 --- a/pkg/schemadsl/compiler/translator.go +++ b/pkg/schemadsl/compiler/translator.go @@ -840,6 +840,9 @@ func translateImports(itctx importResolutionContext, root *dslNode) error { if err := validateFilepath(importPath); err != nil { return err } + if itctx.sourceFS == nil { + return errors.New("import statement found but no source filesystem was configured for compilation") + } filePath := filepath.Join(itctx.sourcePrefix, importPath) newSourcePrefix := filepath.Dir(filePath) @@ -866,11 +869,14 @@ func translateImports(itctx importResolutionContext, root *dslNode) error { // Do the actual import here // This is a new node provided by the translateImport - parsedImportRoot, err := importFile(itctx.sourceFS, filePath) + parsedImportRoot, importedContent, err := importFile(itctx.sourceFS, filePath) if err != nil { - return toContextError("failed to read import in schema file", "", topLevelNode, itctx.mapper) + return toContextError(err.Error(), "", topLevelNode, itctx.mapper) } + // Register the imported file so position mapping works correctly for cross-file references. + itctx.mapper.RegisterImportedFile(input.Source(filePath), importedContent) + // We recurse on that node to resolve any further imports err = translateImports(importResolutionContext{ sourceFS: itctx.sourceFS, diff --git a/pkg/schemadsl/input/inputsource.go b/pkg/schemadsl/input/inputsource.go index 38433b091..c317f9883 100644 --- a/pkg/schemadsl/input/inputsource.go +++ b/pkg/schemadsl/input/inputsource.go @@ -58,6 +58,9 @@ type PositionMapper interface { // TextForLine returns the text for the specified line number. TextForLine(lineNumber int, path Source) (string, error) + + // RegisterImportedFile registers the imported file within the mapper for cross-file references to work correctly. + RegisterImportedFile(source Source, content string) } // SourceRange represents a range inside a source file.