|
| 1 | +package graphql |
| 2 | + |
| 3 | +import ( |
| 4 | + "context" |
| 5 | + "fmt" |
| 6 | + |
| 7 | + "github.com/graphql-go/graphql/gqlerrors" |
| 8 | + "github.com/graphql-go/graphql/language/ast" |
| 9 | +) |
| 10 | + |
| 11 | +type ResultIteratorFn func(count int64, result *Result, doneFunc func()) |
| 12 | + |
| 13 | +type ResultIterator struct { |
| 14 | + count int64 |
| 15 | + ctx context.Context |
| 16 | + ch chan *Result |
| 17 | + cancelFunc context.CancelFunc |
| 18 | + cancelled bool |
| 19 | + handlers []ResultIteratorFn |
| 20 | +} |
| 21 | + |
| 22 | +func NewResultIterator(ctx context.Context, ch chan *Result) *ResultIterator { |
| 23 | + if ctx == nil { |
| 24 | + ctx = context.Background() |
| 25 | + } |
| 26 | + |
| 27 | + cctx, cancelFunc := context.WithCancel(ctx) |
| 28 | + iterator := &ResultIterator{ |
| 29 | + count: 0, |
| 30 | + ctx: cctx, |
| 31 | + ch: ch, |
| 32 | + cancelFunc: cancelFunc, |
| 33 | + cancelled: false, |
| 34 | + handlers: []ResultIteratorFn{}, |
| 35 | + } |
| 36 | + |
| 37 | + go func() { |
| 38 | + for { |
| 39 | + select { |
| 40 | + case <-iterator.ctx.Done(): |
| 41 | + return |
| 42 | + case res := <-iterator.ch: |
| 43 | + if iterator.cancelled { |
| 44 | + return |
| 45 | + } |
| 46 | + iterator.count += 1 |
| 47 | + for _, handler := range iterator.handlers { |
| 48 | + handler(iterator.count, res, iterator.Done) |
| 49 | + } |
| 50 | + } |
| 51 | + } |
| 52 | + }() |
| 53 | + |
| 54 | + return iterator |
| 55 | +} |
| 56 | + |
| 57 | +func (c *ResultIterator) ForEach(handler ResultIteratorFn) { |
| 58 | + c.handlers = append(c.handlers, handler) |
| 59 | +} |
| 60 | + |
| 61 | +func (c *ResultIterator) Done() { |
| 62 | + c.cancelled = true |
| 63 | + c.cancelFunc() |
| 64 | +} |
| 65 | + |
| 66 | +type SubscribeParams struct { |
| 67 | + Schema Schema |
| 68 | + Document *ast.Document |
| 69 | + RootValue interface{} |
| 70 | + ContextValue context.Context |
| 71 | + VariableValues map[string]interface{} |
| 72 | + OperationName string |
| 73 | + FieldResolver FieldResolveFn |
| 74 | + FieldSubscriber FieldResolveFn |
| 75 | +} |
| 76 | + |
| 77 | +// Subscribe performs a subscribe operation |
| 78 | +func Subscribe(p SubscribeParams) *ResultIterator { |
| 79 | + resultChannel := make(chan *Result) |
| 80 | + // Use background context if no context was provided |
| 81 | + ctx := p.ContextValue |
| 82 | + if ctx == nil { |
| 83 | + ctx = context.Background() |
| 84 | + } |
| 85 | + |
| 86 | + var mapSourceToResponse = func(payload interface{}) *Result { |
| 87 | + return Execute(ExecuteParams{ |
| 88 | + Schema: p.Schema, |
| 89 | + Root: payload, |
| 90 | + AST: p.Document, |
| 91 | + OperationName: p.OperationName, |
| 92 | + Args: p.VariableValues, |
| 93 | + Context: p.ContextValue, |
| 94 | + }) |
| 95 | + } |
| 96 | + |
| 97 | + go func() { |
| 98 | + |
| 99 | + result := &Result{} |
| 100 | + defer func() { |
| 101 | + if err := recover(); err != nil { |
| 102 | + result.Errors = append(result.Errors, gqlerrors.FormatError(err.(error))) |
| 103 | + } |
| 104 | + resultChannel <- result |
| 105 | + }() |
| 106 | + |
| 107 | + exeContext, err := buildExecutionContext(buildExecutionCtxParams{ |
| 108 | + Schema: p.Schema, |
| 109 | + Root: p.RootValue, |
| 110 | + AST: p.Document, |
| 111 | + OperationName: p.OperationName, |
| 112 | + Args: p.VariableValues, |
| 113 | + Result: result, |
| 114 | + Context: p.ContextValue, |
| 115 | + }) |
| 116 | + |
| 117 | + if err != nil { |
| 118 | + result.Errors = append(result.Errors, gqlerrors.FormatError(err.(error))) |
| 119 | + resultChannel <- result |
| 120 | + return |
| 121 | + } |
| 122 | + |
| 123 | + operationType, err := getOperationRootType(p.Schema, exeContext.Operation) |
| 124 | + if err != nil { |
| 125 | + result.Errors = append(result.Errors, gqlerrors.FormatError(err.(error))) |
| 126 | + resultChannel <- result |
| 127 | + return |
| 128 | + } |
| 129 | + |
| 130 | + fields := collectFields(collectFieldsParams{ |
| 131 | + ExeContext: exeContext, |
| 132 | + RuntimeType: operationType, |
| 133 | + SelectionSet: exeContext.Operation.GetSelectionSet(), |
| 134 | + }) |
| 135 | + |
| 136 | + responseNames := []string{} |
| 137 | + for name := range fields { |
| 138 | + responseNames = append(responseNames, name) |
| 139 | + } |
| 140 | + responseName := responseNames[0] |
| 141 | + fieldNodes := fields[responseName] |
| 142 | + fieldNode := fieldNodes[0] |
| 143 | + fieldName := fieldNode.Name.Value |
| 144 | + fieldDef := getFieldDef(p.Schema, operationType, fieldName) |
| 145 | + |
| 146 | + if fieldDef == nil { |
| 147 | + err := fmt.Errorf("the subscription field %q is not defined", fieldName) |
| 148 | + result.Errors = append(result.Errors, gqlerrors.FormatError(err.(error))) |
| 149 | + resultChannel <- result |
| 150 | + return |
| 151 | + } |
| 152 | + |
| 153 | + resolveFn := p.FieldSubscriber |
| 154 | + if resolveFn == nil { |
| 155 | + resolveFn = DefaultResolveFn |
| 156 | + } |
| 157 | + if fieldDef.Subscribe != nil { |
| 158 | + resolveFn = fieldDef.Subscribe |
| 159 | + } |
| 160 | + fieldPath := &ResponsePath{ |
| 161 | + Key: responseName, |
| 162 | + } |
| 163 | + |
| 164 | + args := getArgumentValues(fieldDef.Args, fieldNode.Arguments, exeContext.VariableValues) |
| 165 | + info := ResolveInfo{ |
| 166 | + FieldName: fieldName, |
| 167 | + FieldASTs: fieldNodes, |
| 168 | + Path: fieldPath, |
| 169 | + ReturnType: fieldDef.Type, |
| 170 | + ParentType: operationType, |
| 171 | + Schema: p.Schema, |
| 172 | + Fragments: exeContext.Fragments, |
| 173 | + RootValue: exeContext.Root, |
| 174 | + Operation: exeContext.Operation, |
| 175 | + VariableValues: exeContext.VariableValues, |
| 176 | + } |
| 177 | + |
| 178 | + fieldResult, err := resolveFn(ResolveParams{ |
| 179 | + Source: p.RootValue, |
| 180 | + Args: args, |
| 181 | + Info: info, |
| 182 | + Context: exeContext.Context, |
| 183 | + }) |
| 184 | + if err != nil { |
| 185 | + result.Errors = append(result.Errors, gqlerrors.FormatError(err.(error))) |
| 186 | + resultChannel <- result |
| 187 | + return |
| 188 | + } |
| 189 | + |
| 190 | + if fieldResult == nil { |
| 191 | + err := fmt.Errorf("no field result") |
| 192 | + result.Errors = append(result.Errors, gqlerrors.FormatError(err.(error))) |
| 193 | + resultChannel <- result |
| 194 | + return |
| 195 | + } |
| 196 | + |
| 197 | + switch fieldResult.(type) { |
| 198 | + case chan interface{}: |
| 199 | + for { |
| 200 | + select { |
| 201 | + case <-ctx.Done(): |
| 202 | + fmt.Printf("done context called") |
| 203 | + return |
| 204 | + case res := <-fieldResult.(chan interface{}): |
| 205 | + |
| 206 | + resultChannel <- mapSourceToResponse(res) |
| 207 | + } |
| 208 | + } |
| 209 | + default: |
| 210 | + resultChannel <- mapSourceToResponse(fieldResult) |
| 211 | + return |
| 212 | + } |
| 213 | + }() |
| 214 | + |
| 215 | + // return a result iterator |
| 216 | + return NewResultIterator(p.ContextValue, resultChannel) |
| 217 | +} |
0 commit comments