Skip to content

Commit 0909018

Browse files
committed
Format
1 parent d4f4ec1 commit 0909018

22 files changed

+5293
-3772
lines changed

.scalafmt.conf

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1-
align = more
1+
align.preset = "more"
22
maxColumn = 100
3-
version=2.7.5
3+
version=3.0.0-RC2
4+
runner.dialect = scala3
5+
indent.main = 3

backends/.jvm/src/main/scala/NCF.scala

Lines changed: 29 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -19,38 +19,35 @@ import io.kjaer.compiletime._
1919
class NCF(byteArray: Array[Byte], userIdsMap: Map[Long, Long], itemIdsMap: Map[Long, Long])
2020
extends AutoCloseable {
2121

22+
val fullORTBackend = new ORTModelBackend(byteArray)
2223

23-
val fullORTBackend = new ORTModelBackend(byteArray)
24-
25-
def fullNCF(
26-
inputDataactual_input_1: Tensor[Long, Axes],
27-
inputDatalearned_0: Tensor[Long, Axes]
28-
): Tensor[Float, Axes] = {
24+
def fullNCF(
25+
inputDataactual_input_1: Tensor[Long, Axes],
26+
inputDatalearned_0: Tensor[Long, Axes]
27+
): Tensor[Float, Axes] = {
2928
// val scope = new PointerScope()
30-
val nodeactual_input_1 = Tuple1((
31-
inputDataactual_input_1.data.map(y => userIdsMap(y)),
32-
inputDataactual_input_1.shape)
33-
)
34-
35-
val tensorType: String with Singleton = "TensorType"
36-
val nodelearned_0 = Tuple1((inputDatalearned_0.data.map(y => itemIdsMap(y)), inputDatalearned_0.shape))
37-
38-
//Note: Don't need to specify all the type params except in Dotty
39-
val nodeFullOutput: Tensor[Float, Axes] =
40-
fullORTBackend
41-
.fullModel[Float, "TensorType", "DimensionDenotation" ##: TSNil, 1 #: 1000 #: SNil](
42-
//TODO: testing less than enough inputs
43-
(nodeactual_input_1)
44-
)
45-
46-
47-
48-
49-
nodeFullOutput //.asInstanceOf[Tensor[Float]] //Bad
50-
}
51-
52-
override def close(): Unit = {
53-
fullORTBackend.close()
54-
55-
}
29+
val nodeactual_input_1 = Tuple1(
30+
(inputDataactual_input_1.data.map(y => userIdsMap(y)), inputDataactual_input_1.shape)
31+
)
32+
33+
val tensorType: String with Singleton = "TensorType"
34+
val nodelearned_0 = Tuple1(
35+
(inputDatalearned_0.data.map(y => itemIdsMap(y)), inputDatalearned_0.shape)
36+
)
37+
38+
//Note: Don't need to specify all the type params except in Dotty
39+
val nodeFullOutput: Tensor[Float, Axes] =
40+
fullORTBackend
41+
.fullModel[Float, "TensorType", "DimensionDenotation" ##: TSNil, 1 #: 1000 #: SNil](
42+
//TODO: testing less than enough inputs
43+
(nodeactual_input_1)
44+
)
45+
46+
nodeFullOutput //.asInstanceOf[Tensor[Float]] //Bad
47+
}
48+
49+
override def close(): Unit = {
50+
fullORTBackend.close()
51+
52+
}
5653
}

backends/.jvm/src/main/scala/ORTModelBackend.scala

Lines changed: 42 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -18,50 +18,58 @@ class ORTModelBackend(onnxBytes: Array[Byte])
1818
with ORTOperatorBackend
1919
with AutoCloseable {
2020

21-
def getInputAndOutputNodeNamesAndDims(sess: OrtSession) = {
22-
val input_node_names = session.getInputNames
23-
24-
val inputNodeDims = session.getInputInfo.values.asScala.map(_.getInfo.asInstanceOf[TensorInfo].getShape)
21+
def getInputAndOutputNodeNamesAndDims(sess: OrtSession) = {
22+
val input_node_names = session.getInputNames
2523

26-
val output_node_names = session.getOutputNames
24+
val inputNodeDims =
25+
session.getInputInfo.values.asScala.map(_.getInfo.asInstanceOf[TensorInfo].getShape)
2726

28-
(input_node_names.asScala.toList, inputNodeDims.toArray, output_node_names.asScala.toList)
29-
}
27+
val output_node_names = session.getOutputNames
3028

31-
val session = getSession(onnxBytes)
29+
(input_node_names.asScala.toList, inputNodeDims.toArray, output_node_names.asScala.toList)
30+
}
3231

33-
val allNodeNamesAndDims = getInputAndOutputNodeNamesAndDims(session)
32+
val session = getSession(onnxBytes)
3433

35-
override def fullModel[
36-
T <: Supported,
37-
Tt <: TensorTypeDenotation, Td <: TensorShapeDenotation, S <: Shape
38-
](
39-
inputs: Tuple
40-
)(using tt: ValueOf[Tt], td: TensorShapeDenotationOf[Td], s: ShapeOf[S]): Tensor[T, Tuple3[Tt, Td, S]] = {
34+
val allNodeNamesAndDims = getInputAndOutputNodeNamesAndDims(session)
4135

42-
val size = inputs.size
43-
val inputTensors = (0 until size).map { i =>
44-
val tup = inputs.drop(i).take(1)
45-
tup match { //Spurious warning here, see: https://github.com/lampepfl/dotty/issues/10318
36+
override def fullModel[
37+
T <: Supported,
38+
Tt <: TensorTypeDenotation,
39+
Td <: TensorShapeDenotation,
40+
S <: Shape
41+
](
42+
inputs: Tuple
43+
)(using
44+
tt: ValueOf[Tt],
45+
td: TensorShapeDenotationOf[Td],
46+
s: ShapeOf[S]
47+
): Tensor[T, Tuple3[Tt, Td, S]] = {
48+
49+
val size = inputs.size
50+
val inputTensors = (0 until size).map { i =>
51+
val tup = inputs.drop(i).take(1)
52+
tup match { //Spurious warning here, see: https://github.com/lampepfl/dotty/issues/10318
4653
case t: Tuple1[_] =>
47-
t(0) match {
48-
case tens: Tensor[T, Tuple3[Tt, Td, S]] => getOnnxTensor(tens.data, tens.shape, env)
49-
}
50-
}
51-
}.toArray
54+
t(0) match {
55+
case tens: Tensor[T, Tuple3[Tt, Td, S]] =>
56+
getOnnxTensor(tens.data, tens.shape, env)
57+
}
58+
}
59+
}.toArray
5260

53-
val output = runModel[T, Tt, Td, S](
54-
session,
55-
inputTensors,
56-
allNodeNamesAndDims._1,
57-
allNodeNamesAndDims._3
58-
)
61+
val output = runModel[T, Tt, Td, S](
62+
session,
63+
inputTensors,
64+
allNodeNamesAndDims._1,
65+
allNodeNamesAndDims._3
66+
)
5967

60-
output //.asInstanceOf[Tensor[T, Tuple3[Tt, Td, S]]]
61-
}
68+
output //.asInstanceOf[Tensor[T, Tuple3[Tt, Td, S]]]
69+
}
6270

63-
override def close(): Unit = {
71+
override def close(): Unit = {
6472
// executable.close
6573
// super.close
66-
}
74+
}
6775
}

0 commit comments

Comments
 (0)