@@ -2,17 +2,19 @@ package neo4j
22
33import (
44 "bytes"
5+ "context"
56 "fmt"
6- "golang.org/x/mod/semver"
77 "io"
88 neturl "net/url"
99 "strconv"
1010 "sync/atomic"
1111
12+ "golang.org/x/mod/semver"
13+
1214 "github.com/golang-migrate/migrate/v4/database"
1315 "github.com/golang-migrate/migrate/v4/database/multistmt"
1416 "github.com/hashicorp/go-multierror"
15- "github.com/neo4j/neo4j-go-driver/v4 /neo4j"
17+ "github.com/neo4j/neo4j-go-driver/v5 /neo4j"
1618)
1719
1820func init () {
@@ -38,14 +40,14 @@ type Config struct {
3840}
3941
4042type Neo4j struct {
41- driver neo4j.Driver
43+ driver neo4j.DriverWithContext
4244 lock uint32
4345
4446 // Open and WithInstance need to guarantee that config is never nil
4547 config * Config
4648}
4749
48- func WithInstance (driver neo4j.Driver , config * Config ) (database.Driver , error ) {
50+ func WithInstance (driver neo4j.DriverWithContext , config * Config ) (database.Driver , error ) {
4951 if config == nil {
5052 return nil , ErrNilConfig
5153 }
@@ -70,31 +72,16 @@ func (n *Neo4j) Open(url string) (database.Driver, error) {
7072 password , _ := uri .User .Password ()
7173 authToken := neo4j .BasicAuth (uri .User .Username (), password , "" )
7274 uri .User = nil
73- uri .Scheme = "bolt"
7475 msQuery := uri .Query ().Get ("x-multi-statement" )
7576
76- // Whether to turn on/off TLS encryption.
77- tlsEncrypted := uri .Query ().Get ("x-tls-encrypted" )
7877 multi := false
79- encrypted := false
8078 if msQuery != "" {
8179 multi , err = strconv .ParseBool (uri .Query ().Get ("x-multi-statement" ))
8280 if err != nil {
8381 return nil , err
8482 }
8583 }
8684
87- if tlsEncrypted != "" {
88- encrypted , err = strconv .ParseBool (tlsEncrypted )
89- if err != nil {
90- return nil , err
91- }
92- }
93-
94- if encrypted {
95- uri .Scheme += "+s"
96- }
97-
9885 multiStatementMaxSize := DefaultMultiStatementMaxSize
9986 if s := uri .Query ().Get ("x-multi-statement-max-size" ); s != "" {
10087 multiStatementMaxSize , err = strconv .Atoi (s )
@@ -105,11 +92,15 @@ func (n *Neo4j) Open(url string) (database.Driver, error) {
10592
10693 uri .RawQuery = ""
10794
108- driver , err := neo4j .NewDriver (uri .String (), authToken , func ( config * neo4j. Config ) {} )
95+ driver , err := neo4j .NewDriverWithContext (uri .String (), authToken )
10996 if err != nil {
11097 return nil , err
11198 }
11299
100+ if err = driver .VerifyConnectivity (context .Background ()); err != nil {
101+ return nil , err
102+ }
103+
113104 return WithInstance (driver , & Config {
114105 MigrationsLabel : DefaultMigrationsLabel ,
115106 MultiStatement : multi ,
@@ -118,7 +109,7 @@ func (n *Neo4j) Open(url string) (database.Driver, error) {
118109}
119110
120111func (n * Neo4j ) Close () error {
121- return n .driver .Close ()
112+ return n .driver .Close (context . Background () )
122113}
123114
124115// local locking in order to pass tests, Neo doesn't support database locking
@@ -138,60 +129,71 @@ func (n *Neo4j) Unlock() error {
138129}
139130
140131func (n * Neo4j ) Run (migration io.Reader ) (err error ) {
141- session := n .driver .NewSession (neo4j.SessionConfig {AccessMode : neo4j .AccessModeWrite })
132+ ctx := context .Background ()
133+ session := n .driver .NewSession (ctx , neo4j.SessionConfig {AccessMode : neo4j .AccessModeWrite })
142134 defer func () {
143- if cerr := session .Close (); cerr != nil {
135+ if cerr := session .Close (ctx ); cerr != nil {
144136 err = multierror .Append (err , cerr )
145137 }
146138 }()
147139
148140 if n .config .MultiStatement {
149- _ , err = session .WriteTransaction (func (transaction neo4j.Transaction ) (interface {}, error ) {
150- var stmtRunErr error
151- if err := multistmt .Parse (migration , StatementSeparator , n .config .MultiStatementMaxSize , func (stmt []byte ) bool {
152- trimStmt := bytes .TrimSpace (stmt )
153- if len (trimStmt ) == 0 {
154- return true
155- }
156- trimStmt = bytes .TrimSuffix (trimStmt , StatementSeparator )
157- if len (trimStmt ) == 0 {
158- return true
159- }
160-
161- result , err := transaction .Run (string (trimStmt ), nil )
162- if _ , err := neo4j .Collect (result , err ); err != nil {
163- stmtRunErr = err
164- return false
165- }
141+ tx , err := session .BeginTransaction (ctx )
142+ if err != nil {
143+ return err
144+ }
145+ defer func () {
146+ if cerr := tx .Close (ctx ); cerr != nil {
147+ err = multierror .Append (err , cerr )
148+ }
149+ }()
150+
151+ var stmtRunErr error
152+ if err := multistmt .Parse (migration , StatementSeparator , n .config .MultiStatementMaxSize , func (stmt []byte ) bool {
153+ trimStmt := bytes .TrimSpace (stmt )
154+ if len (trimStmt ) == 0 {
166155 return true
167- }); err != nil {
168- return nil , err
169156 }
170- return nil , stmtRunErr
171- })
172- return err
157+ trimStmt = bytes .TrimSuffix (trimStmt , StatementSeparator )
158+ if len (trimStmt ) == 0 {
159+ return true
160+ }
161+
162+ result , err := tx .Run (ctx , string (trimStmt ), nil )
163+ if _ , err := neo4j .CollectWithContext (ctx , result , err ); err != nil {
164+ stmtRunErr = err
165+ return false
166+ }
167+ return true
168+ }); err != nil {
169+ return err
170+ }
171+ return stmtRunErr
173172 }
174173
175174 body , err := io .ReadAll (migration )
176175 if err != nil {
177176 return err
178177 }
179178
180- _ , err = neo4j .Collect (session .Run (string (body [:]), nil ))
179+ res , err := session .Run (ctx , string (body [:]), nil )
180+ _ , err = neo4j .CollectWithContext (ctx , res , err )
181181 return err
182182}
183183
184184func (n * Neo4j ) SetVersion (version int , dirty bool ) (err error ) {
185- session := n .driver .NewSession (neo4j.SessionConfig {AccessMode : neo4j .AccessModeWrite })
185+ ctx := context .Background ()
186+ session := n .driver .NewSession (ctx , neo4j.SessionConfig {AccessMode : neo4j .AccessModeWrite })
186187 defer func () {
187- if cerr := session .Close (); cerr != nil {
188+ if cerr := session .Close (ctx ); cerr != nil {
188189 err = multierror .Append (err , cerr )
189190 }
190191 }()
191192
192193 query := fmt .Sprintf ("MERGE (sm:%s {version: $version}) SET sm.dirty = $dirty, sm.ts = datetime()" ,
193194 n .config .MigrationsLabel )
194- _ , err = neo4j .Collect (session .Run (query , map [string ]interface {}{"version" : version , "dirty" : dirty }))
195+ res , err := session .Run (ctx , query , map [string ]interface {}{"version" : version , "dirty" : dirty })
196+ _ , err = neo4j .CollectWithContext (ctx , res , err )
195197 if err != nil {
196198 return err
197199 }
@@ -204,75 +206,73 @@ type MigrationRecord struct {
204206}
205207
206208func (n * Neo4j ) Version () (version int , dirty bool , err error ) {
207- session := n .driver .NewSession (neo4j.SessionConfig {AccessMode : neo4j .AccessModeRead })
209+ ctx := context .Background ()
210+ session := n .driver .NewSession (ctx , neo4j.SessionConfig {AccessMode : neo4j .AccessModeRead })
208211 defer func () {
209- if cerr := session .Close (); cerr != nil {
212+ if cerr := session .Close (ctx ); cerr != nil {
210213 err = multierror .Append (err , cerr )
211214 }
212215 }()
213216
214217 query := fmt .Sprintf (`MATCH (sm:%s) RETURN sm.version AS version, sm.dirty AS dirty
215218ORDER BY COALESCE(sm.ts, datetime({year: 0})) DESC, sm.version DESC LIMIT 1` ,
216219 n .config .MigrationsLabel )
217- result , err := session .ReadTransaction (func (transaction neo4j.Transaction ) (interface {}, error ) {
218- result , err := transaction .Run (query , nil )
219- if err != nil {
220- return nil , err
221- }
222- if result .Next () {
223- record := result .Record ()
224- mr := MigrationRecord {}
225- versionResult , ok := record .Get ("version" )
226- if ! ok {
227- mr .Version = database .NilVersion
228- } else {
229- mr .Version = int (versionResult .(int64 ))
230- }
231220
232- dirtyResult , ok := record .Get ("dirty" )
233- if ok {
234- mr .Dirty = dirtyResult .(bool )
235- }
221+ tx , err := session .BeginTransaction (ctx )
236222
237- return mr , nil
238- }
239- return nil , result .Err ()
240- })
223+ result , err := tx .Run (ctx , query , nil )
241224 if err != nil {
242225 return database .NilVersion , false , err
243226 }
244- if result == nil {
245- return database .NilVersion , false , err
227+ if result .Next (ctx ) {
228+ record := result .Record ()
229+ mr := MigrationRecord {}
230+ versionResult , ok := record .Get ("version" )
231+ if ! ok {
232+ mr .Version = database .NilVersion
233+ } else {
234+ mr .Version = int (versionResult .(int64 ))
235+ }
236+
237+ dirtyResult , ok := record .Get ("dirty" )
238+ if ok {
239+ mr .Dirty = dirtyResult .(bool )
240+ }
241+
242+ return mr .Version , mr .Dirty , nil
246243 }
247- mr := result .( MigrationRecord )
248- return mr . Version , mr . Dirty , err
244+
245+ return database . NilVersion , false , err
249246}
250247
251248func (n * Neo4j ) Drop () (err error ) {
252- session := n .driver .NewSession (neo4j.SessionConfig {AccessMode : neo4j .AccessModeWrite })
249+ ctx := context .Background ()
250+ session := n .driver .NewSession (ctx , neo4j.SessionConfig {AccessMode : neo4j .AccessModeWrite })
253251 defer func () {
254- if cerr := session .Close (); cerr != nil {
252+ if cerr := session .Close (ctx ); cerr != nil {
255253 err = multierror .Append (err , cerr )
256254 }
257255 }()
258256
259- if _ , err := neo4j .Collect (session .Run ("MATCH (n) DETACH DELETE n" , nil )); err != nil {
257+ res , err := session .Run (ctx , "MATCH (n) DETACH DELETE n" , nil )
258+ if _ , err := neo4j .CollectWithContext (ctx , res , err ); err != nil {
260259 return err
261260 }
262261 return nil
263262}
264263
265264func (n * Neo4j ) ensureVersionConstraint () (err error ) {
266- session := n .driver .NewSession (neo4j.SessionConfig {AccessMode : neo4j .AccessModeWrite })
265+ ctx := context .Background ()
266+ session := n .driver .NewSession (ctx , neo4j.SessionConfig {AccessMode : neo4j .AccessModeWrite })
267267 defer func () {
268- if cerr := session .Close (); cerr != nil {
268+ if cerr := session .Close (ctx ); cerr != nil {
269269 err = multierror .Append (err , cerr )
270270 }
271271 }()
272272
273273 var neo4jVersion string
274-
275- res , err := neo4j .Collect ( session . Run ( "call dbms.components() yield versions unwind versions as version return version" , nil ) )
274+ result , err := session . Run ( ctx , "call dbms.components() yield versions unwind versions as version return version" , nil )
275+ res , err := neo4j .CollectWithContext ( ctx , result , err )
276276 if err != nil {
277277 return err
278278 }
@@ -287,7 +287,8 @@ func (n *Neo4j) ensureVersionConstraint() (err error) {
287287 using db.labels() to support Neo4j 3 and 4.
288288 Neo4J 3 doesn't support db.constraints() YIELD name
289289 */
290- res , err = neo4j .Collect (session .Run (fmt .Sprintf ("CALL db.labels() YIELD label WHERE label=\" %s\" RETURN label" , n .config .MigrationsLabel ), nil ))
290+ result , err = session .Run (ctx , fmt .Sprintf ("CALL db.labels() YIELD label WHERE label=\" %s\" RETURN label" , n .config .MigrationsLabel ), nil )
291+ res , err = neo4j .CollectWithContext (ctx , result , err )
291292 if err != nil {
292293 return err
293294 }
@@ -299,13 +300,14 @@ func (n *Neo4j) ensureVersionConstraint() (err error) {
299300 switch neo4jVersion {
300301 case "v5" :
301302 query = fmt .Sprintf ("CREATE CONSTRAINT FOR (a:%s) REQUIRE a.version IS UNIQUE" , n .config .MigrationsLabel )
302- case "v3" , " v4" :
303+ case "v4" :
303304 query = fmt .Sprintf ("CREATE CONSTRAINT ON (a:%s) ASSERT a.version IS UNIQUE" , n .config .MigrationsLabel )
304305 default :
305306 return fmt .Errorf ("unsupported neo4j version %v" , neo4jVersion )
306307 }
307308
308- if _ , err := neo4j .Collect (session .Run (query , nil )); err != nil {
309+ result , err = session .Run (ctx , query , nil )
310+ if _ , err := neo4j .CollectWithContext (ctx , result , err ); err != nil {
309311 return err
310312 }
311313 return nil
0 commit comments