@@ -16,7 +16,7 @@ mutable struct TracedRArray{T,N} <: RArray{T,N}
16
16
end
17
17
end
18
18
19
- function Base. getindex (a:: TracedRArray{T,0} ) where T
19
+ function Base. getindex (a:: TracedRArray{T,0} ) where {T}
20
20
return a
21
21
end
22
22
75
75
function Base. setindex! (
76
76
a:: TracedRArray{T,N} , v, indices:: Vararg{Union{Base.AbstractUnitRange,Colon},N}
77
77
) where {T,N}
78
- indices = [(promote_to (TracedRArray{Int, 0 }, i isa Colon ? 1 : first (i))- 1 ). mlir_data for i in indices]
78
+ indices = [
79
+ (promote_to (TracedRArray{Int,0 }, i isa Colon ? 1 : first (i)) - 1 ). mlir_data for
80
+ i in indices
81
+ ]
79
82
v = promote_to (TracedRArray{T,N}, v)
80
83
res = MLIR. IR. result (
81
- MLIR. Dialects. stablehlo. dynamic_update_slice (
82
- a. mlir_data, v. mlir_data, indices
83
- ),
84
- 1 ,
84
+ MLIR. Dialects. stablehlo. dynamic_update_slice (a. mlir_data, v. mlir_data, indices), 1
85
85
)
86
86
a. mlir_data = res
87
87
return v
162
162
163
163
function promote_to (:: Type{TracedRArray{T,N}} , rhs) where {T,N}
164
164
if isa (rhs, TracedRArray)
165
- if typeof (rhs) == TracedRArray{T, N}
165
+ if typeof (rhs) == TracedRArray{T,N}
166
166
return rhs
167
167
end
168
168
return TracedRArray {T,N} (
@@ -185,9 +185,12 @@ function promote_to(::Type{TracedRArray{T,N}}, rhs) where {T,N}
185
185
end
186
186
T0 = eltype (rhs)
187
187
attr = MLIR. IR. DenseElementsAttribute (collect (rhs))
188
- return promote_to (TracedRArray{T, N}, TracedRArray {T0,length(size(rhs))} (
189
- (), MLIR. IR. result (MLIR. Dialects. stablehlo. constant (; value= attr), 1 ), size (rhs)
190
- ))
188
+ return promote_to (
189
+ TracedRArray{T,N},
190
+ TracedRArray {T0,length(size(rhs))} (
191
+ (), MLIR. IR. result (MLIR. Dialects. stablehlo. constant (; value= attr), 1 ), size (rhs)
192
+ ),
193
+ )
191
194
end
192
195
193
196
function promote_to (lhs:: TracedRArray{T,N} , rhs) where {T,N}
@@ -491,31 +494,32 @@ for (jlop, hloop, hlocomp, merge) in (
491
494
function elem_apply (
492
495
fn:: typeof ($ jlop), @nospecialize (lhs:: TracedRArray{T,N} ), @nospecialize (rhs)
493
496
) where {T,N}
494
- elem_apply (fn, lhs, promote_to (lhs, rhs))
497
+ return elem_apply (fn, lhs, promote_to (lhs, rhs))
495
498
end
496
499
497
500
function elem_apply (
498
501
:: typeof ($ jlop), @nospecialize (lhs), @nospecialize (rhs:: TracedRArray{T,N} )
499
502
) where {T,N}
500
- elem_apply (fn, promote_to (rhs, lhs), rhs)
503
+ return elem_apply (fn, promote_to (rhs, lhs), rhs)
501
504
end
502
505
503
- function $jlop (@nospecialize (lhs:: TracedRArray{T,N} ), @nospecialize (rhs)
504
- ) where {T, N}
505
- $ jlop (lhs, promote_to (lhs, rhs))
506
+ function $jlop (
507
+ @nospecialize (lhs:: TracedRArray{T,N} ), @nospecialize (rhs)
508
+ ) where {T,N}
509
+ return $ jlop (lhs, promote_to (lhs, rhs))
506
510
end
507
511
508
- function $jlop (@nospecialize (lhs), @nospecialize (rhs:: TracedRArray{T,N} )
509
- ) where {T, N}
510
- $ jlop (promote_to (rhs, lhs), rhs)
512
+ function $jlop (
513
+ @nospecialize (lhs), @nospecialize (rhs:: TracedRArray{T,N} )
514
+ ) where {T,N}
515
+ return $ jlop (promote_to (rhs, lhs), rhs)
511
516
end
512
517
end
513
-
518
+
514
519
if merge != nothing
515
520
@eval begin
516
521
function $jlop (
517
- @nospecialize (lhs:: TracedRArray{T,N} ),
518
- @nospecialize (rhs:: TracedRArray{T,N} )
522
+ @nospecialize (lhs:: TracedRArray{T,N} ), @nospecialize (rhs:: TracedRArray{T,N} )
519
523
) where {T,N}
520
524
elems = elem_apply ($ jlop, lhs, rhs)
521
525
if N == 0
@@ -528,10 +532,9 @@ for (jlop, hloop, hlocomp, merge) in (
528
532
else
529
533
@eval begin
530
534
function $jlop (
531
- @nospecialize (lhs:: TracedRArray{T,0} ),
532
- @nospecialize (rhs:: TracedRArray{T,0} )
535
+ @nospecialize (lhs:: TracedRArray{T,0} ), @nospecialize (rhs:: TracedRArray{T,0} )
533
536
) where {T}
534
- elem_apply ($ jlop, lhs, rhs)
537
+ return elem_apply ($ jlop, lhs, rhs)
535
538
end
536
539
end
537
540
end
0 commit comments