
Commit 43a1576

Format
1 parent 25c6d92 commit 43a1576

6 files changed: +83 −72 lines

backends/.js/src/main/scala/ORTWebModelBackend.scala

Lines changed: 8 additions & 6 deletions
@@ -68,12 +68,14 @@ class ORTWebModelBackend(session: IO[InferenceSession]) extends Model() with ORT
       .map(_.toArray)

     val output = inputTensors.flatMap { tns =>
-      IO{runModel[T, Tt, Td, S](
-        session,
-        tns,
-        inputNames,
-        outputNames
-      )}
+      IO {
+        runModel[T, Tt, Td, S](
+          session,
+          tns,
+          inputNames,
+          outputNames
+        )
+      }
     }
     output.flatten
   }
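Note: the wrapped call still yields a nested effect, because runModel returns a Tensor, which is itself an IO (see the Tensors.scala alias below); flatMap plus IO { ... } therefore produces an IO[IO[...]], which is why the method ends with output.flatten. A minimal sketch of the same shape, using hypothetical stand-in types rather than the real backend signatures:

import cats.effect.IO

// Hypothetical stand-in; the real runModel returns a Tensor, which is itself an IO.
def runModelSketch(input: Array[Float]): IO[Array[Float]] =
  IO(input.map(_ * 2.0f))

val inputData: IO[Array[Float]] = IO(Array(1.0f, 2.0f, 3.0f))

// flatMap over the inputs while wrapping the call in IO yields a nested effect ...
val nested: IO[IO[Array[Float]]] = inputData.flatMap(tns => IO(runModelSketch(tns)))
// ... which is why the original method ends with output.flatten.
val output: IO[Array[Float]] = nested.flatten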

backends/.js/src/main/scala/ORTWebOperatorBackend.scala

Lines changed: 9 additions & 9 deletions
@@ -144,16 +144,16 @@ trait ORTWebOperatorBackend extends OpToONNXBytesConverter {
     val result: IO[Tensor[T, Tuple3[Tt, Td, S]]] =
       for {
         mp <- modelProto.flatMap(IO.println("OpName => " + opName).as(_))
-      } yield {
+      } yield {
         // println(mp)
-      callByteArrayOp(
-        mp.toByteArray,
-        inputs,
-        IO.pure {
-          mp.graph.map(_.input.map(_.name.getOrElse(""))).getOrElse(List[String]()).toList
-        }
-      )
-      }
+        callByteArrayOp(
+          mp.toByteArray,
+          inputs,
+          IO.pure {
+            mp.graph.map(_.input.map(_.name.getOrElse(""))).getOrElse(List[String]()).toList
+          }
+        )
+      }

     result.flatten
   }
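Note: the for-comprehension keeps the existing logging idiom, modelProto.flatMap(IO.println("OpName => " + opName).as(_)), which prints a line and then passes the value through unchanged. A tiny generic sketch of that tap pattern (hypothetical helper, not part of this codebase):

import cats.effect.IO

// Print a debug line, then return the original value untouched.
def tapLog[A](fa: IO[A], opName: String): IO[A] =
  fa.flatMap(a => IO.println("OpName => " + opName).as(a))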

backends/.jvm/src/main/scala/ORTModelBackend.scala

Lines changed: 11 additions & 8 deletions
@@ -71,15 +71,18 @@ class ORTModelBackend(onnxBytes: Array[Byte])
       .sequence
       .map(_.toArray)

-    val output = cats.effect.Resource.make(inputTensors)(inTens => IO{inTens.map(_.close)}).use(inTens =>
-      IO{runModel[T, Tt, Td, S](
-        session,
-        inTens,
-        allNodeNamesAndDims._1,
-        allNodeNamesAndDims._3
+    val output = cats.effect.Resource
+      .make(inputTensors)(inTens => IO { inTens.map(_.close) })
+      .use(inTens =>
+        IO {
+          runModel[T, Tt, Td, S](
+            session,
+            inTens,
+            allNodeNamesAndDims._1,
+            allNodeNamesAndDims._3
+          )
+        }
       )
-      }
-    )

     output.unsafeRunSync()
   }
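Note: the reformatted block is the standard cats-effect bracket pattern, Resource.make(acquire)(release).use(run), used here so the ONNX Runtime input tensors are closed after the model run even if it fails. A minimal sketch with hypothetical stand-ins for the tensors:

import cats.effect.{IO, Resource}

// Hypothetical closeable stand-in for an ONNX Runtime OnnxTensor.
final class FakeTensor { def close(): Unit = println("closed") }

val inputTensors: IO[Array[FakeTensor]] = IO(Array(new FakeTensor, new FakeTensor))

// Acquire the tensors, run with them, and guarantee close() afterwards.
val output: IO[Int] =
  Resource
    .make(inputTensors)(tens => IO(tens.foreach(_.close())))
    .use(tens => IO(tens.length)) // stand-in for the actual runModel call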

backends/.jvm/src/main/scala/ORTOperatorBackend.scala

Lines changed: 46 additions & 40 deletions
@@ -59,24 +59,27 @@ trait ORTOperatorBackend extends OpToONNXBytesConverter with AutoCloseable {
     val tensorTypeDenotationFromType = tt.value
     val tensorShapeDenotationFromType = td.value

-    val tensArr: IO[Array[T]] = cats.effect.Resource.make(IO.blocking{sess.run(inputs)})(outTens => IO{outTens.close}).use(outTens =>
-      {
-        val firstOut = outTens.get(0).asInstanceOf[OnnxTensor]
-        val shape = firstOut.getInfo.getShape.map(_.toInt)
+    val tensArr: IO[Array[T]] = cats.effect.Resource
+      .make(IO.blocking { sess.run(inputs) })(outTens => IO { outTens.close })
+      .use(outTens => {
+        val firstOut = outTens.get(0).asInstanceOf[OnnxTensor]
+        val shape = firstOut.getInfo.getShape.map(_.toInt)

-        require(shape sameElements shapeFromType.toSeq)
-        IO.blocking{getArrayFromOnnxTensor(firstOut)}
-      }
-    )
+        require(shape sameElements shapeFromType.toSeq)
+        IO.blocking { getArrayFromOnnxTensor(firstOut) }
+      })

     // TODO: Denotations
-    val result: Tensor[T, Tuple3[Tt, Td, S]] = tensArr.map(x => Tensor(
-      x,
-      tensorTypeDenotationFromType,
-      tensorShapeDenotationFromType,
-      shapeFromType
-    )
-    ).unsafeRunSync()
+    val result: Tensor[T, Tuple3[Tt, Td, S]] = tensArr
+      .map(x =>
+        Tensor(
+          x,
+          tensorTypeDenotationFromType,
+          tensorShapeDenotationFromType,
+          shapeFromType
+        )
+      )
+      .unsafeRunSync()
     // result.flatMap(IO.println("Invoking run").as(_))
     result
   }
@@ -145,21 +148,24 @@ trait ORTOperatorBackend extends OpToONNXBytesConverter with AutoCloseable {
   def res: Tensor[T, Tuple3[Tt, Td, S]] = {
     // val resource = cats.effect.Resource.make(IO{getSession(opModel)})(sess => IO{sess.close})
     // resource.use( sess =>
-    cats.effect.Resource.make(inputTensors)(inTens => IO{inTens.map(_.close)}).use(inTens =>
-      input_node_names.flatMap { y =>
-        cats.effect.Resource
-          .make(IO.blocking(getSession(opModel)))(sess => IO { sess.close })
-          .use(sess =>
-            IO{runModel(
-              sess,
-              inTens,
-              y,
-              output_node_names
+    cats.effect.Resource
+      .make(inputTensors)(inTens => IO { inTens.map(_.close) })
+      .use(inTens =>
+        input_node_names.flatMap { y =>
+          cats.effect.Resource
+            .make(IO.blocking(getSession(opModel)))(sess => IO { sess.close })
+            .use(sess =>
+              IO {
+                runModel(
+                  sess,
+                  inTens,
+                  y,
+                  output_node_names
+                )
+              }
             )
-          }
-        )
-      }
-    )
+        }
+      )
   }.unsafeRunSync()
   // res.flatMap(IO.println("Post run").as(_))
   res
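Note: this hunk nests two Resource.make(...).use(...) blocks, one scoping the input tensors and one scoping the session. As a side note (not what the commit does), cats-effect Resources also compose monadically, so the two scopes could be combined in a single for-comprehension before one use call; releases then run in reverse order of acquisition. A hedged sketch with made-up acquire/release pairs:

import cats.effect.{IO, Resource}

// Hypothetical acquire/release pairs standing in for the input tensors and the session.
def tensorsRes: Resource[IO, Array[Float]] =
  Resource.make(IO(Array(1.0f, 2.0f)))(_ => IO.unit)
def sessionRes: Resource[IO, String] =
  Resource.make(IO("session"))(_ => IO.unit)

// Compose both scopes; sessionRes is released first, then tensorsRes.
val both: Resource[IO, (Array[Float], String)] =
  for {
    tens <- tensorsRes
    sess <- sessionRes
  } yield (tens, sess)

val result: IO[Int] = both.use { case (tens, sess) => IO(tens.length + sess.length) }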
@@ -184,16 +190,17 @@ trait ORTOperatorBackend extends OpToONNXBytesConverter with AutoCloseable {

     val result: IO[Tensor[T, Tuple3[Tt, Td, S]]] =
       for {
-        mp <- modelProto //modelProto.flatMap(IO.println("OpName => " + opName).as(_))
+        mp <- modelProto // modelProto.flatMap(IO.println("OpName => " + opName).as(_))
       } yield callByteArrayOp(
-      mp.toByteArray,
-      inputs,
-      IO.pure {
-        mp.graph.map(_.input.map(_.name.getOrElse(""))).getOrElse(List[String]()).toList
-      }
-    )
-    val r = result.unsafeRunSync() // If don't use unsafe here, we get redundant callOp invocations. If we memoize w/ unsafe, we leak memory.
-    r //This approach makes callOp sync/eager again.
+        mp.toByteArray,
+        inputs,
+        IO.pure {
+          mp.graph.map(_.input.map(_.name.getOrElse(""))).getOrElse(List[String]()).toList
+        }
+      )
+    val r =
+      result.unsafeRunSync() // If don't use unsafe here, we get redundant callOp invocations. If we memoize w/ unsafe, we leak memory.
+    r // This approach makes callOp sync/eager again.
   }

   def modelToPersist(mod: ModelProto, outName: String) = {
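Note: the comment on result.unsafeRunSync() reflects how cats-effect IO works: an IO value is a description that is re-executed every time it is run or sequenced, so leaving callOp fully lazy would rebuild and re-run the op on every use, while (per the comment) memoizing it with unsafe code leaked memory. A minimal sketch of the re-execution behaviour in general (plain cats-effect 3, not this codebase):

import cats.effect.IO
import cats.effect.unsafe.implicits.global

val expensive: IO[Int] = IO { println("running op"); 42 }

// Sequencing the same IO twice re-runs the effect: "running op" prints twice.
val twice: IO[(Int, Int)] = for { a <- expensive; b <- expensive } yield (a, b)
val ran: (Int, Int) = twice.unsafeRunSync()

// Forcing it once up front runs the effect a single time; later uses see the plain value.
val forced: Int = expensive.unsafeRunSync()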
@@ -204,6 +211,5 @@ trait ORTOperatorBackend extends OpToONNXBytesConverter with AutoCloseable {
     mod.clearGraph.withGraph(graphToPersist)
   }

-  override def close(): Unit = {
-  }
+  override def close(): Unit = {}
 }

core/src/main/scala/OpToONNXBytesConverter.scala

Lines changed: 8 additions & 8 deletions
@@ -269,10 +269,10 @@ trait OpToONNXBytesConverter {
           t
         }
         Some(
-          createInputValueInfoProto(in, name).map { inf =>
-            (inf, name)
-          }
-        )
+          createInputValueInfoProto(in, name).map { inf =>
+            (inf, name)
+          }
+        )

       }
     case None => None
@@ -286,10 +286,10 @@ trait OpToONNXBytesConverter {
           t
         }
         Some(
-          createInputValueInfoProto(tens, name).map { inf =>
-            (inf, name)
-          }
-        )
+          createInputValueInfoProto(tens, name).map { inf =>
+            (inf, name)
+          }
+        )
       }
     }
   }

core/src/main/scala/Tensors.scala

Lines changed: 1 addition & 1 deletion
@@ -31,7 +31,7 @@ object Tensors {
   // Need this alias to not conflict with other Tensors
   // TODO: consider using TF-Java ndarray as backing instead of Scala Array here
   // S is overloaded
-  type Tensor[T <: Supported, +Ax <: Axes] = IO[InnerTensor[T, Ax]]
+  type Tensor[T <: Supported, +Ax <: Axes] = IO[InnerTensor[T, Ax]]
   opaque type InnerTensor[T <: Supported, +Ax <: Axes] = Tuple2[Array[T], Ax]

   type SparseTensor[T <: Supported, A <: Axes] = Tensor[T, A]
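Note: because Tensor[T, Ax] is an alias for IO[InnerTensor[T, Ax]], every tensor flowing through the backends above is a lazy effect; call sites compose with map and flatMap, and the backends force evaluation with unsafeRunSync() only at the edge. A simplified sketch of what the alias means for callers (stand-in types, not the real Axes machinery):

import cats.effect.IO

// Simplified stand-ins: a "tensor" is an effect that yields its data plus its axes.
type FakeAxes   = List[Int]
type FakeTensor = IO[(Array[Float], FakeAxes)]

val t: FakeTensor = IO((Array(1.0f, 2.0f, 3.0f), List(3)))

// Callers compose lazily; nothing executes until something at the edge runs the IO.
val doubled: FakeTensor = t.map { case (data, axes) => (data.map(_ * 2.0f), axes) }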
