Skip to content

Commit 053c60f

Browse files
committed
Format
1 parent a3ad5e1 commit 053c60f

File tree

8 files changed

+969
-911
lines changed

8 files changed

+969
-911
lines changed

backends/src/main/scala/NCF213.scala

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,13 +36,14 @@ class NCF(byteArray: Array[Byte], userIdsMap: Map[Long, Long], itemIdsMap: Map[L
3636
//Note: Don't need to specify all the type params except in Dotty
3737
val nodeFullOutput: Tensor[Float] =
3838
fullNgraphHandler
39-
.fullModel[Option[Tensor[Long]], Option[Tensor[Long]], Any, Any, Any, Any, Any, Any, Any, Tensor[
39+
.fullModel[Option[Tensor[Long]], Option[
40+
Tensor[Long]
41+
], Any, Any, Any, Any, Any, Any, Any, Tensor[
4042
Float
4143
], Any, Any, Any, Any, Any, Any, Any, Any](
4244
(Some(nodeactual_input_1), Some(nodelearned_0), None, None, None, None, None, None, None)
4345
)
4446

45-
4647
System.runFinalization
4748
nodeFullOutput
4849
}

backends/src/main/scala/ORTModelBackend213.scala

Lines changed: 30 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@ import org.bytedeco.onnxruntime.global.onnxruntime._
99

1010
import org.emergentorder.onnx._
1111

12-
1312
//TODO: Clean up, remove asInstaceOf, multiple inputs, etc.
1413
class ORTModelBackend(onnxBytes: Array[Byte])
1514
extends Model(onnxBytes)
@@ -56,27 +55,27 @@ class ORTModelBackend(onnxBytes: Array[Byte])
5655
val allNodeNamesAndDims = getInputAndOutputNodeNamesAndDims(session)
5756

5857
override def fullModel[
59-
T: ClassTag,
60-
T1: ClassTag,
61-
T2: ClassTag,
62-
T3: ClassTag,
63-
T4: ClassTag,
64-
T5: ClassTag,
65-
T6: ClassTag,
66-
T7: ClassTag,
67-
T8: ClassTag,
68-
T9: ClassTag,
69-
T10: ClassTag,
70-
T11: ClassTag,
71-
T12: ClassTag,
72-
T13: ClassTag,
73-
T14: ClassTag,
74-
T15: ClassTag,
75-
T16: ClassTag,
76-
T17: ClassTag
77-
](
78-
inputs: Tuple9[T, T1, T2, T3, T4, T5, T6, T7, T8]
79-
): (T9) = {
58+
T: ClassTag,
59+
T1: ClassTag,
60+
T2: ClassTag,
61+
T3: ClassTag,
62+
T4: ClassTag,
63+
T5: ClassTag,
64+
T6: ClassTag,
65+
T7: ClassTag,
66+
T8: ClassTag,
67+
T9: ClassTag,
68+
T10: ClassTag,
69+
T11: ClassTag,
70+
T12: ClassTag,
71+
T13: ClassTag,
72+
T14: ClassTag,
73+
T15: ClassTag,
74+
T16: ClassTag,
75+
T17: ClassTag
76+
](
77+
inputs: Tuple9[T, T1, T2, T3, T4, T5, T6, T7, T8]
78+
): (T9) = {
8079

8180
val inputTensors = Array(
8281
getInput(inputs._1),
@@ -100,20 +99,20 @@ class ORTModelBackend(onnxBytes: Array[Byte])
10099
// val outputPointer = out.get(0).GetTensorMutableDataFloat().capacity(inputs.GetTensorTypeAndShapeInfo().GetElementCount());
101100

102101
// println(outputPointer.get(0).IsTensor())
103-
102+
104103
output.asInstanceOf[T9]
105104
}
106105

107106
def getInput[T: ClassTag](
108-
input: T
109-
): Option[Value] = {
110-
input match {
111-
case tensorOpt: Option[Tensor[Any]] => {
112-
tensorOpt match {
113-
case Some(x) => Some(getTensor(x))
114-
case None => None
115-
}
107+
input: T
108+
): Option[Value] = {
109+
input match {
110+
case tensorOpt: Option[Tensor[Any]] => {
111+
tensorOpt match {
112+
case Some(x) => Some(getTensor(x))
113+
case None => None
116114
}
115+
}
117116
case tensor: Tensor[Any] => {
118117
Some(getTensor(tensor))
119118
}

0 commit comments

Comments
 (0)