Skip to content

Commit 7d750b4

Browse files
Update dependencies and fix device support
1 parent eb5f9b8 commit 7d750b4

File tree

4 files changed

+1984
-1390
lines changed

4 files changed

+1984
-1390
lines changed

python/egglog/exp/array_api_program_gen.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -505,6 +505,6 @@ def bin_op(res: NDArray, op: str) -> Command:
505505
yield rewrite(ndarray_program(abs(x))).to((Program("np.abs(") + ndarray_program(x) + ")").assign())
506506

507507
# asarray
508-
yield rewrite(ndarray_program(asarray(x, odtype))).to(
508+
yield rewrite(ndarray_program(asarray(x, odtype, OptionalBool.none, optional_device_))).to(
509509
Program("np.asarray(") + ndarray_program(x) + ", " + optional_dtype_program(odtype) + ")"
510510
)

python/tests/__snapshots__/test_array_api/test_jit[lda][expr].py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,14 @@
4040
Value.float(Float.rational(BigRat(BigInt.from_string("1"), BigInt.from_string("1"))))
4141
)
4242
_TupleNDArray_1 = svd(
43-
sqrt(asarray(NDArray.scalar(Value.float(Float.rational(BigRat(BigInt.from_string("1"), BigInt.from_string("147"))))), OptionalDType.some(DType.float64)))
43+
sqrt(
44+
asarray(
45+
NDArray.scalar(Value.float(Float.rational(BigRat(BigInt.from_string("1"), BigInt.from_string("147"))))),
46+
OptionalDType.some(DType.float64),
47+
OptionalBool.none,
48+
OptionalDevice.some(_NDArray_1.device),
49+
)
50+
)
4451
* (_NDArray_8 / _NDArray_11),
4552
Boolean(False),
4653
)

python/tests/__snapshots__/test_array_api/test_jit[lda][initial_expr].py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,14 @@
4545
Value.float(Float.rational(BigRat(BigInt.from_string("1"), BigInt.from_string("1"))))
4646
)
4747
_TupleNDArray_1 = svd(
48-
sqrt(asarray(NDArray.scalar(Value.float(Float.rational(BigRat(BigInt.from_string("1"), BigInt.from_string("147"))))), OptionalDType.some(DType.float64)))
48+
sqrt(
49+
asarray(
50+
NDArray.scalar(Value.float(Float.rational(BigRat(BigInt.from_string("1"), BigInt.from_string("147"))))),
51+
OptionalDType.some(DType.float64),
52+
OptionalBool.none,
53+
OptionalDevice.some(_NDArray_1.device),
54+
)
55+
)
4956
* (_NDArray_5 / _NDArray_6),
5057
Boolean(False),
5158
)

0 commit comments

Comments
 (0)