Skip to content

Commit ef1afce

Browse files
Merge pull request #272 from egraphs-good/builtins
Easily create pure egglog output of jitted function
2 parents e8be112 + 18f4a14 commit ef1afce

File tree

9 files changed

+113
-68
lines changed

9 files changed

+113
-68
lines changed

python/egglog/egraph.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1241,10 +1241,10 @@ def register(
12411241

12421242
def _register_commands(self, cmds: list[Command]) -> None:
12431243
self._add_decls(*cmds)
1244-
egg_cmds = list(map(self._command_to_egg, cmds))
1244+
egg_cmds = [egg_cmd for cmd in cmds if (egg_cmd := self._command_to_egg(cmd)) is not None]
12451245
self._egraph.run_program(*egg_cmds)
12461246

1247-
def _command_to_egg(self, cmd: Command) -> bindings._Command:
1247+
def _command_to_egg(self, cmd: Command) -> bindings._Command | None:
12481248
ruleset_name = ""
12491249
cmd_decl: CommandDecl
12501250
match cmd:

python/egglog/egraph_state.py

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import re
88
from collections import defaultdict
99
from dataclasses import dataclass, field
10-
from typing import TYPE_CHECKING, overload
10+
from typing import TYPE_CHECKING, Literal, overload
1111

1212
from typing_extensions import assert_never
1313

@@ -116,7 +116,8 @@ def ruleset_to_egg(self, name: str) -> None:
116116
if rule in added_rules:
117117
continue
118118
cmd = self.command_to_egg(rule, name)
119-
self.egraph.run_program(cmd)
119+
if cmd is not None:
120+
self.egraph.run_program(cmd)
120121
added_rules.add(rule)
121122
case CombinedRulesetDecl(rulesets):
122123
if name in self.rulesets:
@@ -126,10 +127,13 @@ def ruleset_to_egg(self, name: str) -> None:
126127
self.ruleset_to_egg(ruleset)
127128
self.egraph.run_program(bindings.UnstableCombinedRuleset(name, list(rulesets)))
128129

129-
def command_to_egg(self, cmd: CommandDecl, ruleset: str) -> bindings._Command:
130+
def command_to_egg(self, cmd: CommandDecl, ruleset: str) -> bindings._Command | None:
130131
match cmd:
131132
case ActionCommandDecl(action):
132-
return bindings.ActionCommand(self.action_to_egg(action, expr_to_let=True))
133+
action_egg = self.action_to_egg(action, expr_to_let=True)
134+
if not action_egg:
135+
return None
136+
return bindings.ActionCommand(action_egg)
133137
case RewriteDecl(tp, lhs, rhs, conditions) | BiRewriteDecl(tp, lhs, rhs, conditions):
134138
self.type_ref_to_egg(tp)
135139
rewrite = bindings.Rewrite(
@@ -166,7 +170,13 @@ def command_to_egg(self, cmd: CommandDecl, ruleset: str) -> bindings._Command:
166170
case _:
167171
assert_never(cmd)
168172

169-
def action_to_egg(self, action: ActionDecl, expr_to_let: bool = False) -> bindings._Action:
173+
@overload
174+
def action_to_egg(self, action: ActionDecl) -> bindings._Action: ...
175+
176+
@overload
177+
def action_to_egg(self, action: ActionDecl, expr_to_let: Literal[True] = ...) -> bindings._Action | None: ...
178+
179+
def action_to_egg(self, action: ActionDecl, expr_to_let: bool = False) -> bindings._Action | None: # noqa: C901, PLR0911, PLR0912
170180
match action:
171181
case LetDecl(name, typed_expr):
172182
var_decl = VarDecl(name, True)
@@ -179,7 +189,11 @@ def action_to_egg(self, action: ActionDecl, expr_to_let: bool = False) -> bindin
179189
return bindings.Set(span(), call_.name, call_.args, self._expr_to_egg(rhs))
180190
case ExprActionDecl(typed_expr):
181191
if expr_to_let:
182-
typed_expr = self._transform_let(typed_expr)
192+
maybe_typed_expr = self._transform_let(typed_expr)
193+
if maybe_typed_expr:
194+
typed_expr = maybe_typed_expr
195+
else:
196+
return None
183197
return bindings.Expr_(span(), self.typed_expr_to_egg(typed_expr))
184198
case ChangeDecl(tp, call, change):
185199
self.type_ref_to_egg(tp)
@@ -351,13 +365,13 @@ def typed_expr_to_egg(self, typed_expr_decl: TypedExprDecl, transform_let: bool
351365
self.type_ref_to_egg(typed_expr_decl.tp)
352366
return self._expr_to_egg(typed_expr_decl.expr)
353367

354-
def _transform_let(self, typed_expr: TypedExprDecl) -> TypedExprDecl:
368+
def _transform_let(self, typed_expr: TypedExprDecl) -> TypedExprDecl | None:
355369
"""
356370
Rewrites this expression as a let binding if it's not already a let binding.
357371
"""
358372
var_decl = VarDecl(f"__expr_{hash(typed_expr)}", True)
359373
if var_decl in self.expr_to_egg_cache:
360-
return TypedExprDecl(typed_expr.tp, var_decl)
374+
return None
361375
var_egg = self._expr_to_egg(var_decl)
362376
cmd = bindings.ActionCommand(bindings.Let(span(), var_egg.name, self.typed_expr_to_egg(typed_expr)))
363377
try:
@@ -367,7 +381,7 @@ def _transform_let(self, typed_expr: TypedExprDecl) -> TypedExprDecl:
367381
return typed_expr
368382
self.expr_to_egg_cache[typed_expr.expr] = var_egg
369383
self.expr_to_egg_cache[var_decl] = var_egg
370-
return TypedExprDecl(typed_expr.tp, var_decl)
384+
return None
371385

372386
@overload
373387
def _expr_to_egg(self, expr_decl: CallDecl) -> bindings.Call: ...

python/egglog/exp/array_api.py

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -290,8 +290,8 @@ def check_index(length: IntLike, idx: IntLike) -> Int:
290290
"""
291291
Returns the index if 0 <= idx < length, otherwise returns Int.NEVER
292292
"""
293-
length = cast(Int, length)
294-
idx = cast(Int, idx)
293+
length = cast("Int", length)
294+
idx = cast("Int", idx)
295295
return Int.if_(((idx >= 0) & (idx < length)), idx, Int.NEVER)
296296

297297

@@ -336,7 +336,7 @@ def abs(self) -> Float: ...
336336

337337
@method(cost=2)
338338
@classmethod
339-
def rational(cls, r: Rational) -> Float: ...
339+
def rational(cls, r: BigRat) -> Float: ...
340340

341341
@classmethod
342342
def from_int(cls, i: IntLike) -> Float: ...
@@ -362,15 +362,15 @@ def __eq__(self, other: FloatLike) -> Boolean: ... # type: ignore[override]
362362

363363

364364
@array_api_ruleset.register
365-
def _float(fl: Float, f: f64, f2: f64, i: i64, r: Rational, r1: Rational):
365+
def _float(fl: Float, f: f64, f2: f64, i: i64, r: BigRat, r1: BigRat):
366366
return [
367367
rule(eq(fl).to(Float(f))).then(set_(fl.to_f64).to(f)),
368368
rewrite(Float(f).abs()).to(Float(f), f >= 0.0),
369369
rewrite(Float(f).abs()).to(Float(-f), f < 0.0),
370370
# Convert from float to rationl, if its a whole number i.e. can be converted to int
371-
rewrite(Float(f)).to(Float.rational(Rational(f.to_i64(), 1)), eq(f64.from_i64(f.to_i64())).to(f)),
371+
rewrite(Float(f)).to(Float.rational(BigRat(f.to_i64(), 1)), eq(f64.from_i64(f.to_i64())).to(f)),
372372
# always convert from int to rational
373-
rewrite(Float.from_int(Int(i))).to(Float.rational(Rational(i, 1))),
373+
rewrite(Float.from_int(Int(i))).to(Float.rational(BigRat(i, 1))),
374374
rewrite(Float(f) + Float(f2)).to(Float(f + f2)),
375375
rewrite(Float(f) - Float(f2)).to(Float(f - f2)),
376376
rewrite(Float(f) * Float(f2)).to(Float(f * f2)),
@@ -417,7 +417,7 @@ def range(cls, stop: IntLike) -> TupleInt:
417417
def from_vec(cls, vec: VecLike[Int, IntLike]) -> TupleInt: ...
418418

419419
def __add__(self, other: TupleIntLike) -> TupleInt:
420-
other = cast(TupleInt, other)
420+
other = cast("TupleInt", other)
421421
return TupleInt(
422422
self.length() + other.length(), lambda i: Int.if_(i < self.length(), self[i], other[i - self.length()])
423423
)
@@ -475,14 +475,14 @@ def select(self, indices: TupleIntLike) -> TupleInt:
475475
"""
476476
Return a new tuple with the elements at the given indices
477477
"""
478-
indices = cast(TupleInt, indices)
478+
indices = cast("TupleInt", indices)
479479
return indices.map(lambda i: self[i])
480480

481481
def deselect(self, indices: TupleIntLike) -> TupleInt:
482482
"""
483483
Return a new tuple with the elements not at the given indices
484484
"""
485-
indices = cast(TupleInt, indices)
485+
indices = cast("TupleInt", indices)
486486
return TupleInt.range(self.length()).filter(lambda i: ~indices.contains(i)).map(lambda i: self[i])
487487

488488

@@ -554,7 +554,7 @@ def __init__(self, length: IntLike, idx_fn: Callable[[Int], TupleInt]) -> None:
554554
@method(subsume=True)
555555
@classmethod
556556
def single(cls, i: TupleIntLike) -> TupleTupleInt:
557-
i = cast(TupleInt, i)
557+
i = cast("TupleInt", i)
558558
return TupleTupleInt(1, lambda _: i)
559559

560560
@method(subsume=True)
@@ -564,7 +564,7 @@ def from_vec(cls, vec: Vec[TupleInt]) -> TupleTupleInt: ...
564564
def append(self, i: TupleIntLike) -> TupleTupleInt: ...
565565

566566
def __add__(self, other: TupleTupleIntLike) -> TupleTupleInt:
567-
other = cast(TupleTupleInt, other)
567+
other = cast("TupleTupleInt", other)
568568
return TupleTupleInt(
569569
self.length() + other.length(),
570570
lambda i: TupleInt.if_(i < self.length(), self[i], other[i - self.length()]),
@@ -840,7 +840,7 @@ def _value(i: Int, f: Float, b: Boolean, v: Value, v1: Value, i1: Int, f1: Float
840840

841841
yield rewrite(Value.float(f).sqrt()).to(Value.float(f ** (0.5)))
842842

843-
yield rewrite(Value.float(Float.rational(Rational(0, 1))) + v).to(v)
843+
yield rewrite(Value.float(Float.rational(BigRat(0, 1))) + v).to(v)
844844

845845
yield rewrite(Value.if_(TRUE, v, v1)).to(v)
846846
yield rewrite(Value.if_(FALSE, v, v1)).to(v1)
@@ -862,7 +862,7 @@ def append(self, i: ValueLike) -> TupleValue: ...
862862
def from_vec(cls, vec: Vec[Value]) -> TupleValue: ...
863863

864864
def __add__(self, other: TupleValueLike) -> TupleValue:
865-
other = cast(TupleValue, other)
865+
other = cast("TupleValue", other)
866866
return TupleValue(
867867
self.length() + other.length(),
868868
lambda i: Value.if_(i < self.length(), self[i], other[i - self.length()]),
@@ -875,13 +875,13 @@ def __getitem__(self, i: Int) -> Value: ...
875875
def foldl_boolean(self, f: Callable[[Boolean, Value], Boolean], init: BooleanLike) -> Boolean: ...
876876

877877
def contains(self, value: ValueLike) -> Boolean:
878-
value = cast(Value, value)
878+
value = cast("Value", value)
879879
return self.foldl_boolean(lambda acc, j: acc | (value == j), FALSE)
880880

881881
@method(subsume=True)
882882
@classmethod
883883
def from_tuple_int(cls, ti: TupleIntLike) -> TupleValue:
884-
ti = cast(TupleInt, ti)
884+
ti = cast("TupleInt", ti)
885885
return TupleValue(ti.length(), lambda i: Value.int(ti[i]))
886886

887887

@@ -1259,7 +1259,7 @@ def append(self, i: NDArrayLike) -> TupleNDArray: ...
12591259
def from_vec(cls, vec: Vec[NDArray]) -> TupleNDArray: ...
12601260

12611261
def __add__(self, other: TupleNDArrayLike) -> TupleNDArray:
1262-
other = cast(TupleNDArray, other)
1262+
other = cast("TupleNDArray", other)
12631263
return TupleNDArray(
12641264
self.length() + other.length(),
12651265
lambda i: NDArray.if_(i < self.length(), self[i], other[i - self.length()]),
@@ -1632,7 +1632,7 @@ def ndindex(shape: TupleIntLike) -> TupleTupleInt:
16321632
"""
16331633
https://numpy.org/doc/stable/reference/generated/numpy.ndindex.html
16341634
"""
1635-
shape = cast(TupleInt, shape)
1635+
shape = cast("TupleInt", shape)
16361636
return shape.map_tuple_int(TupleInt.range).product()
16371637

16381638

python/egglog/exp/array_api_jit.py

Lines changed: 26 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,14 @@
22
from collections.abc import Callable
33
from typing import TypeVar, cast
44

5+
import numpy as np
6+
57
from egglog import EGraph, try_evaling
68
from egglog.exp.array_api import NDArray
79
from egglog.exp.array_api_numba import array_api_numba_schedule
8-
from egglog.exp.array_api_program_gen import array_api_program_gen_schedule, ndarray_function_two
10+
from egglog.exp.array_api_program_gen import EvalProgram, array_api_program_gen_schedule, ndarray_function_two_program
11+
12+
from .program_gen import Program
913

1014
X = TypeVar("X", bound=Callable)
1115

@@ -14,15 +18,27 @@ def jit(fn: X) -> X:
1418
"""
1519
Jit compiles a function
1620
"""
17-
sig = inspect.signature(fn)
18-
arg1, arg2 = sig.parameters.keys()
19-
egraph = EGraph()
21+
egraph, res, res_optimized, program = function_to_program(fn, save_egglog_string=False)
22+
fn_program = EvalProgram(program, {"np": np})
2023
with egraph.set_current():
21-
res = fn(NDArray.var(arg1), NDArray.var(arg2))
22-
res_optimized = egraph.simplify(res, array_api_numba_schedule)
23-
24-
fn_program = ndarray_function_two(res_optimized, NDArray.var(arg1), NDArray.var(arg2))
25-
fn = try_evaling(array_api_program_gen_schedule, fn_program, fn_program.as_py_object)
24+
fn = cast("X", try_evaling(array_api_program_gen_schedule, fn_program, fn_program.as_py_object))
2625
fn.initial_expr = res # type: ignore[attr-defined]
2726
fn.expr = res_optimized # type: ignore[attr-defined]
28-
return cast(X, fn)
27+
return fn
28+
29+
30+
def function_to_program(fn: Callable, save_egglog_string: bool) -> tuple[EGraph, NDArray, NDArray, Program]:
31+
sig = inspect.signature(fn)
32+
arg1, arg2 = sig.parameters.keys()
33+
egraph = EGraph(save_egglog_string=save_egglog_string)
34+
with egraph:
35+
with egraph.set_current():
36+
res = fn(NDArray.var(arg1), NDArray.var(arg2))
37+
res_optimized = egraph.simplify(res, array_api_numba_schedule)
38+
39+
return (
40+
egraph,
41+
res,
42+
res_optimized,
43+
ndarray_function_two_program(res_optimized, NDArray.var(arg1), NDArray.var(arg2)),
44+
)

python/egglog/exp/array_api_program_gen.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,13 @@
1414
array_api_program_gen_ruleset = ruleset(name="array_api_program_gen_ruleset")
1515
array_api_program_gen_eval_ruleset = ruleset(name="array_api_program_gen_eval_ruleset")
1616

17-
array_api_program_gen_schedule = (
17+
array_api_program_gen_combined_ruleset = (
1818
array_api_program_gen_ruleset
1919
| program_gen_ruleset
2020
| array_api_program_gen_eval_ruleset
21-
| eval_program_rulseset
2221
| array_api_vec_to_cons_ruleset
23-
).saturate()
22+
)
23+
array_api_program_gen_schedule = (array_api_program_gen_combined_ruleset | eval_program_rulseset).saturate()
2424

2525

2626
@function
@@ -116,7 +116,7 @@ def float_program(x: Float) -> Program: ...
116116

117117

118118
@array_api_program_gen_ruleset.register
119-
def _float_program(f: Float, g: Float, f64_: f64, i: Int, r: Rational):
119+
def _float_program(f: Float, g: Float, f64_: f64, i: Int, r: BigRat):
120120
yield rewrite(float_program(Float(f64_))).to(Program(f64_.to_string()))
121121
yield rewrite(float_program(f.abs())).to(Program("np.abs(") + float_program(f) + ")")
122122
yield rewrite(float_program(Float.from_int(i))).to(int_program(i))
@@ -126,10 +126,10 @@ def _float_program(f: Float, g: Float, f64_: f64, i: Int, r: Rational):
126126
yield rewrite(float_program(f / g)).to(Program("(") + float_program(f) + " / " + float_program(g) + ")")
127127
yield rewrite(float_program(Float.rational(r))).to(
128128
Program("float(") + Program(r.numer.to_string()) + " / " + Program(r.denom.to_string()) + ")",
129-
ne(r.denom).to(i64(1)),
129+
ne(r.denom).to(BigInt(1)),
130130
)
131131
yield rewrite(float_program(Float.rational(r))).to(
132-
Program("float(") + Program(r.numer.to_string()) + ")", eq(r.denom).to(i64(1))
132+
Program("float(") + Program(r.numer.to_string()) + ")", eq(r.denom).to(BigInt(1))
133133
)
134134

135135

python/tests/__snapshots__/test_array_api/test_jit[lda][expr].py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
)
1818
),
1919
DType.float64,
20-
) / NDArray.scalar(Value.float(Float.rational(Rational(150, 1))))
20+
) / NDArray.scalar(Value.float(Float.rational(BigRat(BigInt.from_string("150"), BigInt.from_string("1")))))
2121
_NDArray_4 = zeros(TupleInt.from_vec(Vec[Int](Int(3), Int(4))), OptionalDType.some(DType.float64), OptionalDevice.some(_NDArray_1.device))
2222
_MultiAxisIndexKeyItem_1 = MultiAxisIndexKeyItem.slice(Slice())
2323
_IndexKey_1 = IndexKey.multi_axis(MultiAxisIndexKey.from_vec(Vec[MultiAxisIndexKeyItem](MultiAxisIndexKeyItem.int(Int(0)), _MultiAxisIndexKeyItem_1)))
@@ -36,15 +36,24 @@
3636
_NDArray_9 = square(_NDArray_8 - expand_dims(sum(_NDArray_8, _OptionalIntOrTuple_1) / NDArray.scalar(Value.int(_NDArray_8.shape[Int(0)]))))
3737
_NDArray_10 = sqrt(sum(_NDArray_9, _OptionalIntOrTuple_1) / NDArray.scalar(Value.int(_NDArray_9.shape[Int(0)])))
3838
_NDArray_11 = copy(_NDArray_10)
39-
_NDArray_11[IndexKey.ndarray(_NDArray_10 == NDArray.scalar(Value.int(Int(0))))] = NDArray.scalar(Value.float(Float.rational(Rational(1, 1))))
40-
_TupleNDArray_1 = svd(sqrt(asarray(NDArray.scalar(Value.float(Float.rational(Rational(1, 147)))), OptionalDType.some(DType.float64))) * (_NDArray_8 / _NDArray_11), Boolean(False))
39+
_NDArray_11[IndexKey.ndarray(_NDArray_10 == NDArray.scalar(Value.int(Int(0))))] = NDArray.scalar(
40+
Value.float(Float.rational(BigRat(BigInt.from_string("1"), BigInt.from_string("1"))))
41+
)
42+
_TupleNDArray_1 = svd(
43+
sqrt(asarray(NDArray.scalar(Value.float(Float.rational(BigRat(BigInt.from_string("1"), BigInt.from_string("147"))))), OptionalDType.some(DType.float64)))
44+
* (_NDArray_8 / _NDArray_11),
45+
Boolean(False),
46+
)
4147
_Slice_1 = Slice(OptionalInt.none, OptionalInt.some(sum(astype(_TupleNDArray_1[Int(1)] > NDArray.scalar(Value.float(Float(0.0001))), DType.int32)).to_value().to_int))
4248
_NDArray_12 = (
4349
_TupleNDArray_1[Int(2)][IndexKey.multi_axis(MultiAxisIndexKey.from_vec(Vec[MultiAxisIndexKeyItem](MultiAxisIndexKeyItem.slice(_Slice_1), _MultiAxisIndexKeyItem_1)))]
4450
/ _NDArray_11
4551
).T / _TupleNDArray_1[Int(1)][IndexKey.slice(_Slice_1)]
4652
_TupleNDArray_2 = svd(
47-
(sqrt((NDArray.scalar(Value.int(Int(150))) * _NDArray_3) * NDArray.scalar(Value.float(Float.rational(Rational(1, 2))))) * (_NDArray_4 - (_NDArray_3 @ _NDArray_4)).T).T
53+
(
54+
sqrt((NDArray.scalar(Value.int(Int(150))) * _NDArray_3) * NDArray.scalar(Value.float(Float.rational(BigRat(BigInt.from_string("1"), BigInt.from_string("2"))))))
55+
* (_NDArray_4 - (_NDArray_3 @ _NDArray_4)).T
56+
).T
4857
@ _NDArray_12,
4958
Boolean(False),
5059
)

python/tests/__snapshots__/test_array_api/test_jit[lda][initial_expr].py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,14 @@
4141
OptionalInt.some(Int(0)),
4242
)
4343
_NDArray_6 = std(_NDArray_5, _OptionalIntOrTuple_1)
44-
_NDArray_6[IndexKey.ndarray(std(_NDArray_5, _OptionalIntOrTuple_1) == NDArray.scalar(Value.int(Int(0))))] = NDArray.scalar(Value.float(Float.rational(Rational(1, 1))))
45-
_TupleNDArray_1 = svd(sqrt(asarray(NDArray.scalar(Value.float(Float.rational(Rational(1, 147)))), OptionalDType.some(DType.float64))) * (_NDArray_5 / _NDArray_6), Boolean(False))
44+
_NDArray_6[IndexKey.ndarray(std(_NDArray_5, _OptionalIntOrTuple_1) == NDArray.scalar(Value.int(Int(0))))] = NDArray.scalar(
45+
Value.float(Float.rational(BigRat(BigInt.from_string("1"), BigInt.from_string("1"))))
46+
)
47+
_TupleNDArray_1 = svd(
48+
sqrt(asarray(NDArray.scalar(Value.float(Float.rational(BigRat(BigInt.from_string("1"), BigInt.from_string("147"))))), OptionalDType.some(DType.float64)))
49+
* (_NDArray_5 / _NDArray_6),
50+
Boolean(False),
51+
)
4652
_Slice_1 = Slice(OptionalInt.none, OptionalInt.some(sum(astype(_TupleNDArray_1[Int(1)] > NDArray.scalar(Value.float(Float(0.0001))), DType.int32)).to_value().to_int))
4753
_NDArray_7 = asarray(reshape(asarray(_NDArray_2), TupleInt.from_vec(Vec[Int](Int(-1)))))
4854
_NDArray_8 = unique_values(concat(TupleNDArray.from_vec(Vec[NDArray](unique_values(asarray(_NDArray_7))))))
@@ -61,9 +67,12 @@
6167
)
6268
_NDArray_10 = copy(_NDArray_9)
6369
_NDArray_10[IndexKey.ndarray(_NDArray_9 == NDArray.scalar(Value.int(Int(0))))] = NDArray.scalar(Value.float(Float(1.0)))
64-
_NDArray_11 = astype(unique_counts(_NDArray_2)[Int(1)], DType.float64) / NDArray.scalar(Value.float(Float.rational(Rational(150, 1))))
70+
_NDArray_11 = astype(unique_counts(_NDArray_2)[Int(1)], DType.float64) / NDArray.scalar(Value.float(Float.rational(BigRat(BigInt.from_string("150"), BigInt.from_string("1")))))
6571
_TupleNDArray_2 = svd(
66-
(sqrt((NDArray.scalar(Value.int(Int(150))) * _NDArray_11) * NDArray.scalar(Value.float(Float.rational(Rational(1, 2))))) * (_NDArray_4 - (_NDArray_11 @ _NDArray_4)).T).T
72+
(
73+
sqrt((NDArray.scalar(Value.int(Int(150))) * _NDArray_11) * NDArray.scalar(Value.float(Float.rational(BigRat(BigInt.from_string("1"), BigInt.from_string("2"))))))
74+
* (_NDArray_4 - (_NDArray_11 @ _NDArray_4)).T
75+
).T
6776
@ (
6877
(
6978
_TupleNDArray_1[Int(2)][IndexKey.multi_axis(MultiAxisIndexKey.from_vec(Vec[MultiAxisIndexKeyItem](MultiAxisIndexKeyItem.slice(_Slice_1), _MultiAxisIndexKeyItem_1)))]

0 commit comments

Comments
 (0)