@@ -9,7 +9,9 @@ import fs2.Stream
99import io .getquill .NamingStrategy
1010import io .getquill .context .sql .idiom .SqlIdiom
1111import io .getquill .context .StreamingContext
12+ import io .getquill .context .ExecutionInfo
1213import java .sql .Connection
14+ import scala .annotation .nowarn
1315import scala .util .Success
1416import scala .util .Try
1517import doobie .enumerated .AutoGeneratedKeys
@@ -38,56 +40,84 @@ trait DoobieContextBase[Dialect <: SqlIdiom, Naming <: NamingStrategy]
3840 // to log.underlying below.
3941 private val log : ContextLogger = new ContextLogger (" DoobieContext" )
4042
43+ private def useConnection [A ](f : Connection => PreparedStatementIO [A ]): PreparedStatementIO [A ] =
44+ FPS .getConnection.flatMap(f)
45+
4146 private def prepareAndLog (
4247 sql : String ,
4348 p : Prepare ,
44- ): PreparedStatementIO [Unit ] = FPS .raw(p).flatMap { case (params, _) =>
49+ )(
50+ implicit connection : Connection
51+ ): PreparedStatementIO [Unit ] = FPS .raw(p(_, connection)).flatMap { case (params, _) =>
4552 FPS .delay(log.logQuery(sql, params))
4653 }
4754
4855 override def executeQuery [A ](
4956 sql : String ,
5057 prepare : Prepare = identityPrepare,
5158 extractor : Extractor [A ] = identityExtractor,
59+ )(
60+ info : ExecutionInfo ,
61+ dc : DatasourceContext ,
5262 ): ConnectionIO [List [A ]] =
5363 HC .prepareStatement(sql) {
54- prepareAndLog(sql, prepare) *>
55- HPS .executeQuery {
56- HRS .list(extractor)
57- }
64+ useConnection { implicit connection =>
65+ prepareAndLog(sql, prepare) *>
66+ HPS .executeQuery {
67+ HRS .list(extractor)
68+ }
69+ }
5870 }
5971
6072 override def executeQuerySingle [A ](
6173 sql : String ,
6274 prepare : Prepare = identityPrepare,
6375 extractor : Extractor [A ] = identityExtractor,
76+ )(
77+ info : ExecutionInfo ,
78+ dc : DatasourceContext ,
6479 ): ConnectionIO [A ] =
6580 HC .prepareStatement(sql) {
66- prepareAndLog(sql, prepare) *>
67- HPS .executeQuery {
68- HRS .getUnique(extractor)
69- }
81+ useConnection { implicit connection =>
82+ prepareAndLog(sql, prepare) *>
83+ HPS .executeQuery {
84+ HRS .getUnique(extractor)
85+ }
86+ }
7087 }
7188
89+ @ nowarn(" msg=is never used" )
7290 def streamQuery [A ](
7391 fetchSize : Option [Int ],
7492 sql : String ,
7593 prepare : Prepare = identityPrepare,
7694 extractor : Extractor [A ] = identityExtractor,
95+ )(
96+ info : ExecutionInfo ,
97+ dc : DatasourceContext ,
7798 ): Stream [ConnectionIO , A ] =
78- HC .stream(
79- sql,
80- prepareAndLog(sql, prepare),
81- fetchSize.getOrElse(DefaultChunkSize ),
82- )(extractor)
99+ for {
100+ connection <- Stream .eval(FC .raw(identity))
101+ result <-
102+ HC .stream(
103+ sql,
104+ prepareAndLog(sql, prepare)(connection),
105+ fetchSize.getOrElse(DefaultChunkSize ),
106+ )(extractorToRead(extractor)(connection))
107+ } yield result
83108
84109 override def executeAction [A ](
85110 sql : String ,
86111 prepare : Prepare = identityPrepare,
112+ )(
113+ info : ExecutionInfo ,
114+ dc : DatasourceContext ,
87115 ): ConnectionIO [Long ] =
88116 HC .prepareStatement(sql) {
89- prepareAndLog(sql, prepare) *>
90- HPS .executeUpdate.map(_.toLong)
117+ useConnection { implicit connection =>
118+ prepareAndLog(sql, prepare) *>
119+ HPS .executeUpdate.map(_.toLong)
120+ }
91121 }
92122
93123 private def prepareConnections [A ](returningBehavior : ReturnAction ) =
@@ -103,42 +133,68 @@ trait DoobieContextBase[Dialect <: SqlIdiom, Naming <: NamingStrategy]
103133 prepare : Prepare = identityPrepare,
104134 extractor : Extractor [A ],
105135 returningBehavior : ReturnAction ,
136+ )(
137+ info : ExecutionInfo ,
138+ dc : DatasourceContext ,
106139 ): ConnectionIO [A ] =
107140 prepareConnections[A ](returningBehavior)(sql) {
108- prepareAndLog(sql, prepare) *>
109- FPS .executeUpdate *>
110- HPS .getGeneratedKeys(HRS .getUnique(extractor))
141+ useConnection { implicit connection =>
142+ prepareAndLog(sql, prepare) *>
143+ FPS .executeUpdate *>
144+ HPS .getGeneratedKeys(HRS .getUnique(extractor))
145+ }
111146 }
112147
113- private def prepareBatchAndLog (sql : String , p : Prepare ): PreparedStatementIO [Unit ] =
114- FPS .raw(p) flatMap { case (params, _) => FPS .delay(log.logBatchItem(sql, params)) }
148+ private def prepareBatchAndLog (
149+ sql : String ,
150+ p : Prepare ,
151+ )(
152+ implicit connection : Connection
153+ ): PreparedStatementIO [Unit ] =
154+ FPS .raw(p(_, connection)) flatMap { case (params, _) =>
155+ FPS .delay(log.logBatchItem(sql, params))
156+ }
115157
116158 override def executeBatchAction (
117159 groups : List [BatchGroup ]
160+ )(
161+ info : ExecutionInfo ,
162+ dc : DatasourceContext ,
118163 ): ConnectionIO [List [Long ]] = groups.flatTraverse { case BatchGroup (sql, preps) =>
119164 HC .prepareStatement(sql) {
120- FPS .delay(log.underlying.debug(" Batch: {}" , sql)) *>
121- preps.traverse(prepareBatchAndLog(sql, _) *> FPS .addBatch) *>
122- Nested (HPS .executeBatch).map(_.toLong).value
165+ useConnection { implicit connection =>
166+ FPS .delay(log.underlying.debug(" Batch: {}" , sql)) *>
167+ preps.traverse(prepareBatchAndLog(sql, _) *> FPS .addBatch) *>
168+ Nested (HPS .executeBatch).map(_.toLong).value
169+ }
123170 }
124171 }
125172
126173 override def executeBatchActionReturning [A ](
127174 groups : List [BatchGroupReturning ],
128175 extractor : Extractor [A ],
176+ )(
177+ info : ExecutionInfo ,
178+ dc : DatasourceContext ,
129179 ): ConnectionIO [List [A ]] = groups.flatTraverse {
130180 case BatchGroupReturning (sql, returningBehavior, preps) =>
131181 prepareConnections(returningBehavior)(sql) {
132- FPS .delay(log.underlying.debug(" Batch: {}" , sql)) *>
133- preps.traverse(prepareBatchAndLog(sql, _) *> FPS .addBatch) *>
134- HPS .executeBatch *>
135- HPS .getGeneratedKeys(HRS .list(extractor))
182+
183+ useConnection { implicit connection =>
184+ FPS .delay(log.underlying.debug(" Batch: {}" , sql)) *>
185+ preps.traverse(prepareBatchAndLog(sql, _) *> FPS .addBatch) *>
186+ HPS .executeBatch *>
187+ HPS .getGeneratedKeys(HRS .list(extractor))
188+ }
136189 }
137190 }
138191
139192 // Turn an extractor into a `Read` so we can use the existing resultset.
140- private implicit def extractorToRead [A ](ex : Extractor [A ]): Read [A ] =
141- new Read [A ](Nil , (rs, _) => ex(rs))
193+ private implicit def extractorToRead [A ](
194+ ex : Extractor [A ]
195+ )(
196+ implicit connection : Connection
197+ ): Read [A ] = new Read [A ](Nil , (rs, _) => ex(rs, connection))
142198
143199 // Nothing to do here.
144200 override def close (): Unit = ()
0 commit comments