Skip to content

Commit 62ba8f3

Browse files
committed
Fix MatMul shape constraints
1 parent b2cc80c commit 62ba8f3

File tree

2 files changed

+8
-10
lines changed

2 files changed

+8
-10
lines changed

build.sbt

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@ import sbtcrossproject.CrossPlugin.autoImport.{crossProject, CrossType}
33
val dottyVersion = "3.0.0-RC1"
44
val scala213Version = "2.13.5"
55
val spireVersion = "0.17.0"
6-
val scalametaVersion = "4.4.10"
76

87
scalaVersion := dottyVersion
98

@@ -17,12 +16,11 @@ lazy val commonSettings = Seq(
1716
updateOptions := updateOptions.value.withLatestSnapshots(false),
1817
scalacOptions ++= Seq("-feature", "-unchecked", "-deprecation"),
1918
autoCompilerPlugins := true,
20-
sources in (Compile, doc) := Seq(), //Bug w/ Dotty & JS on doc
19+
// sources in (Compile, doc) := Seq(), //Bug w/ Dotty & JS on doc
2120
) ++ sonatypeSettings
2221

2322
lazy val common = (crossProject(JSPlatform, JVMPlatform)
2423
.crossType(CrossType.Pure) in file("common"))
25-
// .enablePlugins(ScalaJSPlugin)
2624
.settings(commonSettings, name := "onnx-scala-common",
2725
crossScalaVersions := Seq(
2826
dottyVersion,
@@ -56,7 +54,7 @@ lazy val backends = (crossProject(JSPlatform, JVMPlatform)
5654
.crossType(CrossType.Pure) in file("backends"))
5755
.dependsOn(core)
5856
//conditionally enabling/disable based on version, still not working
59-
// .enablePlugins(ScalaJSBundlerPlugin) //{ScalablyTypedConverterPlugin})
57+
.enablePlugins(ScalaJSBundlerPlugin)//, ScalablyTypedConverterPlugin)
6058
.settings(
6159
commonSettings,
6260
name := "onnx-scala-backends",
@@ -79,10 +77,10 @@ lazy val backends = (crossProject(JSPlatform, JVMPlatform)
7977
),
8078
crossScalaVersions := Seq(dottyVersion, scala213Version)
8179
)
82-
//.jvmSettings().jsSettings(
83-
// scalaJSUseMainModuleInitializer := true) //, //Testing
80+
.jvmSettings().jsSettings(
81+
scalaJSUseMainModuleInitializer := true, //, //Testing
8482
//Seems to be a bundling issue, copying things manually seems to work
85-
// npmDependencies in Compile += "onnxjs" -> "0.1.8")
83+
npmDependencies in Compile += "onnxjs" -> "0.1.8")
8684

8785
lazy val core = (crossProject(JSPlatform, JVMPlatform)
8886
.crossType(CrossType.Pure) in file("core"))
@@ -115,7 +113,7 @@ lazy val core = (crossProject(JSPlatform, JVMPlatform)
115113
})
116114
)
117115

118-
/*
116+
119117
lazy val docs = (crossProject(JVMPlatform)
120118
.crossType(CrossType.Pure) in file("core-docs")) // new documentation project
121119
.settings(
@@ -129,7 +127,7 @@ lazy val docs = (crossProject(JVMPlatform)
129127
.jvmSettings(
130128
crossScalaVersions := Seq(scala213Version)
131129
)
132-
*/
130+
133131
skip in publish := true
134132
sonatypeProfileName := "com.github.EmergentOrder"
135133

core/src/main/scala/ONNX.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2032,7 +2032,7 @@ package object onnx {
20322032
name: String,
20332033
A: Tensor[T, Tuple3[Tt,Td,S]],
20342034
B: Tensor[T, Tuple3[Tt1,Td1,S1]]
2035-
)(using tt: ValueOf[Tt], td: TensorShapeDenotationOf[Td], s: ShapeOf[Dim0 #: Dim2 #: SNil]): Tensor[T, Tuple3[Tt,Td, Dim0 #: Dim2 #: SNil]] = {
2035+
)(using tt: ValueOf[Tt], td: TensorShapeDenotationOf[Td], s: ShapeOf[Dim0 #: Dim2 #: SNil],vd0:ValueOf[scala.compiletime.S[Dim0]],vd1:ValueOf[scala.compiletime.S[Dim1]], vd2: ValueOf[scala.compiletime.S[Dim2]]): Tensor[T, Tuple3[Tt,Td, Dim0 #: Dim2 #: SNil]] = {
20362036
val map: Map[String, Any] = Map()
20372037
val allInputs = Tuple2(A,B)
20382038
(callOp(name, "MatMul", allInputs, map))

0 commit comments

Comments
 (0)