@@ -18,50 +18,58 @@ class ORTModelBackend(onnxBytes: Array[Byte])
18
18
with ORTOperatorBackend
19
19
with AutoCloseable {
20
20
21
- def getInputAndOutputNodeNamesAndDims (sess : OrtSession ) = {
22
- val input_node_names = session.getInputNames
23
-
24
- val inputNodeDims = session.getInputInfo.values.asScala.map(_.getInfo.asInstanceOf [TensorInfo ].getShape)
21
+ def getInputAndOutputNodeNamesAndDims (sess : OrtSession ) = {
22
+ val input_node_names = session.getInputNames
25
23
26
- val output_node_names = session.getOutputNames
24
+ val inputNodeDims =
25
+ session.getInputInfo.values.asScala.map(_.getInfo.asInstanceOf [TensorInfo ].getShape)
27
26
28
- (input_node_names.asScala.toList, inputNodeDims.toArray, output_node_names.asScala.toList)
29
- }
27
+ val output_node_names = session.getOutputNames
30
28
31
- val session = getSession(onnxBytes)
29
+ (input_node_names.asScala.toList, inputNodeDims.toArray, output_node_names.asScala.toList)
30
+ }
32
31
33
- val allNodeNamesAndDims = getInputAndOutputNodeNamesAndDims(session )
32
+ val session = getSession(onnxBytes )
34
33
35
- override def fullModel [
36
- T <: Supported ,
37
- Tt <: TensorTypeDenotation , Td <: TensorShapeDenotation , S <: Shape
38
- ](
39
- inputs : Tuple
40
- )(using tt : ValueOf [Tt ], td : TensorShapeDenotationOf [Td ], s : ShapeOf [S ]): Tensor [T , Tuple3 [Tt , Td , S ]] = {
34
+ val allNodeNamesAndDims = getInputAndOutputNodeNamesAndDims(session)
41
35
42
- val size = inputs.size
43
- val inputTensors = (0 until size).map { i =>
44
- val tup = inputs.drop(i).take(1 )
45
- tup match { // Spurious warning here, see: https://github.com/lampepfl/dotty/issues/10318
36
+ override def fullModel [
37
+ T <: Supported ,
38
+ Tt <: TensorTypeDenotation ,
39
+ Td <: TensorShapeDenotation ,
40
+ S <: Shape
41
+ ](
42
+ inputs : Tuple
43
+ )(using
44
+ tt : ValueOf [Tt ],
45
+ td : TensorShapeDenotationOf [Td ],
46
+ s : ShapeOf [S ]
47
+ ): Tensor [T , Tuple3 [Tt , Td , S ]] = {
48
+
49
+ val size = inputs.size
50
+ val inputTensors = (0 until size).map { i =>
51
+ val tup = inputs.drop(i).take(1 )
52
+ tup match { // Spurious warning here, see: https://github.com/lampepfl/dotty/issues/10318
46
53
case t : Tuple1 [_] =>
47
- t(0 ) match {
48
- case tens : Tensor [T , Tuple3 [Tt , Td , S ]] => getOnnxTensor(tens.data, tens.shape, env)
49
- }
50
- }
51
- }.toArray
54
+ t(0 ) match {
55
+ case tens : Tensor [T , Tuple3 [Tt , Td , S ]] =>
56
+ getOnnxTensor(tens.data, tens.shape, env)
57
+ }
58
+ }
59
+ }.toArray
52
60
53
- val output = runModel[T , Tt , Td , S ](
54
- session,
55
- inputTensors,
56
- allNodeNamesAndDims._1,
57
- allNodeNamesAndDims._3
58
- )
61
+ val output = runModel[T , Tt , Td , S ](
62
+ session,
63
+ inputTensors,
64
+ allNodeNamesAndDims._1,
65
+ allNodeNamesAndDims._3
66
+ )
59
67
60
- output // .asInstanceOf[Tensor[T, Tuple3[Tt, Td, S]]]
61
- }
68
+ output // .asInstanceOf[Tensor[T, Tuple3[Tt, Td, S]]]
69
+ }
62
70
63
- override def close (): Unit = {
71
+ override def close (): Unit = {
64
72
// executable.close
65
73
// super.close
66
- }
74
+ }
67
75
}
0 commit comments