Skip to content

Commit 7ea0005

Browse files
committed
Clean up / triage TODOs; Fix random seed
1 parent 569ff0b commit 7ea0005

File tree

5 files changed

+20
-30
lines changed

5 files changed

+20
-30
lines changed

backends/src/main/scala/NGraphBackend.scala

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -32,11 +32,12 @@ import org.bytedeco.ngraph.global.ngraph.i64
3232
import org.bytedeco.ngraph.global.ngraph.i32
3333
import org.bytedeco.onnx.global.onnx.check_model
3434

35-
//TODO: Extract ModelProto modifications into generic layer, then do each layer-wise
35+
//TODEFER: Tracing Mode: Extract ModelProto modifications into generic layer, then do each layer-wise
3636
// op in a lazy fashion, while doing the same for generating a single overall ModelProto, via ModelProto.MergeFrom.
37+
// Between fine and full modes:
3738
// Use one path for speed and dynamic graph tracing at runtime (the default), the other for sanity/type/shape/AxisType/control flow checking at compile time
38-
//TODO: ONNX-JS backend for both JS and JVM
39-
//TODO: ONNX Runtime backend for JVM (and Native?)
39+
//TODEFER: ONNX-JS backend for both JS and JVM
40+
//TODEFER: ONNX Runtime backend for JVM (and Native?)
4041
class NGraphBackend(onnxBytes: Array[Byte])
4142
extends Add
4243
with DataSource
@@ -724,9 +725,7 @@ class NGraphBackend(onnxBytes: Array[Byte])
724725
outName: String,
725726
attrs: Map[String, Any]
726727
): (ModelProto) = {
727-
//TODO: Refactor op method sigs to return this
728728

729-
//TODO: Fix ModelProto leaking memory here
730729
val model = (new ModelProto).New()
731730
val graph = new org.bytedeco.onnx.GraphProto
732731
model.set_producer_name("ONNX-Scala")
@@ -757,7 +756,7 @@ class NGraphBackend(onnxBytes: Array[Byte])
757756
addInputToGraph(B, bName, graph)
758757
addInputToGraph(C, cName, graph)
759758

760-
//TODO: ensure the outer model is the last merged
759+
//TODEFER: Merge models, ensuring the outer model is the last merged
761760
(model)
762761
}
763762

@@ -864,9 +863,6 @@ class NGraphBackend(onnxBytes: Array[Byte])
864863
C: Option[Tensor[T2]]
865864
): (Tensor[T3]) = {
866865
val scope = new PointerScope()
867-
868-
//println(modelString.getString)
869-
//TODO: Pull this as far forward as possible
870866
val modelString = new BytePointer(opModel: _*)
871867
val ngraphFunc = import_onnx_model(modelString)
872868
modelString.close

programGenerator/src/main/scala/ONNXProgramGenerator.scala

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -25,11 +25,9 @@ import collection.JavaConverters._
2525

2626
import scala.reflect.io.Streamable
2727

28-
//TODO ASAP: Fix references to UnionTypes and ONNXHelper init in generated programs
28+
//TODEFER: Use Squid to clean up / improve this
2929

30-
//TODO: Use Squid to clean up / improve this
31-
32-
//TODO: de-tuple on the left hand side when there are multiple outputs . should also solved the other output TODOs
30+
//TODEFER: de-tuple on the left hand side when there are multiple outputs . should also solve the other output TODOs
3331
object ONNXProgramGenerator {
3432
def main(args: Array[String]): Unit = {
3533

@@ -65,8 +63,8 @@ object ONNXProgramGenerator {
6563
val useZIO = false
6664
val useDotty = false
6765

68-
//TODO: Fix output for the benchmark models shown here: https://github.com/onnx/backend-scoreboard
69-
//TODO: run time benchmarks on the same models
66+
//TODO: Test outputs for the benchmark models shown here: https://github.com/onnx/backend-scoreboard
67+
//TODEFER: run time benchmarks on the same models
7068

7169
val programName = fileName.stripSuffix(".onnx").capitalize + (if (useZIO)
7270
"ZIO"
@@ -75,7 +73,7 @@ object ONNXProgramGenerator {
7573
"programGenerator/src/main/scala/gen/" + programName + ".scala"
7674
);
7775

78-
//TODO: Be explicit about model version, metadata
76+
//TODEFER: Be explicit about model version, metadata
7977
//Notes, from the standard:
8078
//"Each model MUST explicitly name the operator sets that it relies on for its functionality."
8179
//"An implementation must support all operators in the set or reject the model" - This can happen at runtime via ZIOstyle implicits, possibly mixing backends
@@ -330,7 +328,7 @@ object ONNXProgramGenerator {
330328
else
331329
"List(") + opName + (if (useZIO)
332330
"ZIO"
333-
else //TODO: Tell the compiler when one of the type params in an output type
331+
else //TODO: Tell the compiler when one of the type params is an output type
334332
"") + "." + opName + sinceVersion + (if (useZIO)
335333
"ZIO"
336334
else

zio/src/main/scala/NCFZIO.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ import scala.reflect.io.Streamable
1515
import org.bytedeco.javacpp.PointerScope
1616
import org.bytedeco.javacpp.Pointer
1717

18-
//TODO: Add changes to generator; Generate both full model and layerwise programs each time
18+
//TODO: Generate full model as well as layerwise
1919
class NCFZIO(byteArray: Array[Byte], userIdsMap: Map[Long, Long], itemIdsMap: Map[Long, Long])
2020
extends AutoCloseable {
2121

zio/src/main/scala/NCFZIOFine.scala

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@ import scala.language.higherKinds
1313
import scala.io.Source
1414
import scala.reflect.io.Streamable
1515

16-
//TODO: Add changes to generator; Generate both full model and layerwise programs each time
1716
class NCFZIOFineGrained(
1817
byteArray: Array[Byte],
1918
userIdsMap: Map[Long, Long],
@@ -26,7 +25,7 @@ class NCFZIOFineGrained(
2625
val GemmZIO: GemmZIO = new ONNXNGraphHandlers(byteArray)
2726
val ReluZIO: ReluZIO = new ONNXNGraphHandlers(byteArray)
2827
val SigmoidZIO: SigmoidZIO = new ONNXNGraphHandlers(byteArray)
29-
val dataSource: DataSourceZIO = ??? //TODO: Inject
28+
val dataSource: DataSourceZIO = new ONNXNGraphHandlers(byteArray)
3029

3130
def fineNCF(
3231
inputDataactual_input_1: Task[Tensor[Long]],

zio/src/main/scala/ZIONGraphBackend.scala

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ class ONNXNGraphHandlers(onnxBytes: Array[Byte])
3333
with ReluZIO
3434
with ConcatZIO
3535
with DropoutZIO
36-
// with DataSourceZIO
36+
with DataSourceZIO
3737
with AutoCloseable {
3838
val scope = new PointerScope()
3939
val ngraphBackend = new NGraphBackend(onnxBytes)
@@ -46,18 +46,14 @@ class ONNXNGraphHandlers(onnxBytes: Array[Byte])
4646
ngraphBackend.fullModel[T, T1, T2, T3](A, B, C)
4747
}
4848

49-
/*
50-
def getParamsZIO[T: Numeric: ClassTag](name: String)(
51-
implicit ev: (UNil TypeOr Float16 TypeOr Float TypeOr Double TypeOr Byte TypeOr Short TypeOr Int TypeOr Long TypeOr UByte TypeOr UShort TypeOr UInt TypeOr ULong TypeOr Complex[
52-
Float
53-
] TypeOr Complex[Double])#check[T]
54-
): Task[Tensor[T]] = {
49+
def getParamsZIO[T: Numeric: ClassTag](name: String): Task[Tensor[T]] =
50+
{
5551
Task {
5652
ngraphBackend.getParams(name)
5753
}
5854

5955
}
60-
*/
56+
6157
def getAttributesZIO[T: Numeric: ClassTag](name: String)(
6258
implicit evT: Contains[
6359
T,
@@ -340,7 +336,8 @@ object ZIONGraphMain extends App {
340336

341337
val userIds = userIdsMap.keys.toArray
342338
val itemIds = itemIdsMap.keys.toArray
343-
def getRandomId(arr: Array[Long]) = arr(Random.nextInt(arr.length))
339+
val rand = new Random(42)
340+
def getRandomId(arr: Array[Long]) = arr(rand.nextInt(arr.length))
344341

345342
val input = Task {
346343
val tens = TensorFactory.getTensor(
@@ -388,7 +385,7 @@ object ZIONGraphMain extends App {
388385
println(Pointer.maxPhysicalBytes)
389386
println("Output size: " + output2._1.size)
390387
println("Output 0: " + output2._1(0))
391-
println("Output 1: " + output2._1(7999)) //TODO: Investigate output flipping here, possibly due to race
388+
println("Output 1: " + output2._1(7999))
392389
// println("Output 2: " + output2._1(2))
393390
// println("Output 3: " + output2._1(3))
394391
// println("Output 4: " + output2._1(4))

0 commit comments

Comments
 (0)