@@ -16,6 +16,10 @@ mutable struct TracedRArray{T,N} <: RArray{T,N}
16
16
end
17
17
end
18
18
19
+ function Base. getindex (a:: TracedRArray{T,0} ) where T
20
+ return a
21
+ end
22
+
19
23
function Base. getindex (a:: TracedRArray{T,N} , index:: Vararg{Integer,N} ) where {T,N}
20
24
@warn (
21
25
""" Performing scalar indexing on task $(current_task ()) .
@@ -43,6 +47,46 @@ and require expensive copies and synchronization each time and therefore should
43
47
return TracedRArray {T,0} ((), res2, ())
44
48
end
45
49
50
+ function Base. getindex (
51
+ a:: TracedRArray{T,N} , indices:: Vararg{Union{Base.AbstractUnitRange,Colon},N}
52
+ ) where {T,N}
53
+ indices = [i isa Colon ? (1 : size (a, idx)) : i for (idx, i) in enumerate (indices)]
54
+ res = MLIR. IR. result (
55
+ MLIR. Dialects. stablehlo. slice (
56
+ a. mlir_data;
57
+ start_indices= MLIR. IR. DenseArrayAttribute ([
58
+ Int64 (first (i) - 1 ) for i in indices
59
+ ]),
60
+ limit_indices= MLIR. IR. DenseArrayAttribute ([Int64 (last (i)) for i in indices]),
61
+ strides= MLIR. IR. DenseArrayAttribute ([Int64 (1 ) for i in indices]),
62
+ ),
63
+ 1 ,
64
+ )
65
+ return TracedRArray {T,N} ((), res, Tuple (length .(indices)))
66
+ end
67
+
68
+ function Base. view (
69
+ a:: TracedRArray{T,N} , indices:: Vararg{Union{Base.AbstractUnitRange,Colon},N}
70
+ ) where {T,N}
71
+ # TODO : Implement before merging the PR
72
+ return error (" view is not supported yet" )
73
+ end
74
+
75
+ function Base. setindex! (
76
+ a:: TracedRArray{T,N} , v, indices:: Vararg{Union{Base.AbstractUnitRange,Colon},N}
77
+ ) where {T,N}
78
+ indices = [(promote_to (TracedRArray{Int, 0 }, i isa Colon ? 1 : first (i))- 1 ). mlir_data for i in indices]
79
+ v = promote_to (TracedRArray{T,N}, v)
80
+ res = MLIR. IR. result (
81
+ MLIR. Dialects. stablehlo. dynamic_update_slice (
82
+ a. mlir_data, v. mlir_data, indices
83
+ ),
84
+ 1 ,
85
+ )
86
+ a. mlir_data = res
87
+ return v
88
+ end
89
+
46
90
Base. size (x:: TracedRArray ) = x. shape
47
91
48
92
Base. copy (A:: TracedRArray{T,N} ) where {T,N} = TracedRArray ((), A. mlir_data, size (A))
118
162
119
163
function promote_to (:: Type{TracedRArray{T,N}} , rhs) where {T,N}
120
164
if isa (rhs, TracedRArray)
165
+ if typeof (rhs) == TracedRArray{T, N}
166
+ return rhs
167
+ end
121
168
return TracedRArray {T,N} (
122
169
(),
123
170
MLIR. IR. result (
@@ -136,10 +183,11 @@ function promote_to(::Type{TracedRArray{T,N}}, rhs) where {T,N}
136
183
)
137
184
return ta
138
185
end
139
- attr = MLIR. IR. DenseElementsAttribute (mlir_type (TracedRArray{T,N}, size (rhs)), rhs)
140
- return TracedRArray {T,N} (
186
+ T0 = eltype (rhs)
187
+ attr = MLIR. IR. DenseElementsAttribute (collect (rhs))
188
+ return promote_to (TracedRArray{T, N}, TracedRArray {T0,length(size(rhs))} (
141
189
(), MLIR. IR. result (MLIR. Dialects. stablehlo. constant (; value= attr), 1 ), size (rhs)
142
- )
190
+ ))
143
191
end
144
192
145
193
function promote_to (lhs:: TracedRArray{T,N} , rhs) where {T,N}
@@ -410,13 +458,13 @@ function elem_apply(f, args::Vararg{Any,Nargs}) where {Nargs}
410
458
return traced2_result
411
459
end
412
460
413
- for (jlop, hloop, hlocomp) in (
414
- (:(Base.:(== )), :compare , " EQ" ),
415
- (:(Base.:(!= )), :compare , " NE" ),
416
- (:(Base.:(>= )), :compare , " GE" ),
417
- (:(Base.:(> )), :compare , " GT" ),
418
- (:(Base.:(<= )), :compare , " LE" ),
419
- (:(Base.:(< )), :compare , " LT" ),
461
+ for (jlop, hloop, hlocomp, merge ) in (
462
+ (:(Base.:(== )), :compare , " EQ" , :all ),
463
+ (:(Base.:(!= )), :compare , " NE" , :any ),
464
+ (:(Base.:(>= )), :compare , " GE" , nothing ),
465
+ (:(Base.:(> )), :compare , " GT" , nothing ),
466
+ (:(Base.:(<= )), :compare , " LE" , nothing ),
467
+ (:(Base.:(< )), :compare , " LT" , nothing ),
420
468
)
421
469
@eval begin
422
470
function elem_apply (
@@ -441,43 +489,50 @@ for (jlop, hloop, hlocomp) in (
441
489
end
442
490
443
491
function elem_apply (
444
- :: typeof ($ jlop), @nospecialize (lhs:: TracedRArray{T,N} ), @nospecialize (rhs)
492
+ fn :: typeof ($ jlop), @nospecialize (lhs:: TracedRArray{T,N} ), @nospecialize (rhs)
445
493
) where {T,N}
446
- rhs = promote_to (lhs, rhs)
447
- return TracedRArray {T,N} (
448
- (),
449
- MLIR. IR. result (
450
- MLIR. Dialects. stablehlo.$ hloop (
451
- lhs. mlir_data,
452
- rhs. mlir_data;
453
- comparison_direction= MLIR. API. stablehloComparisonDirectionAttrGet (
454
- MLIR. IR. context (), $ hlocomp
455
- ),
456
- ),
457
- 1 ,
458
- ),
459
- size (lhs),
460
- )
494
+ elem_apply (fn, lhs, promote_to (lhs, rhs))
461
495
end
462
496
463
497
function elem_apply (
464
498
:: typeof ($ jlop), @nospecialize (lhs), @nospecialize (rhs:: TracedRArray{T,N} )
465
499
) where {T,N}
466
- lhs = promote_to (rhs, lhs)
467
- return TracedRArray {T,N} (
468
- (),
469
- MLIR. IR. result (
470
- MLIR. Dialects. stablehlo.$ hloop (
471
- lhs. mlir_data,
472
- rhs. mlir_data;
473
- comparison_direction= MLIR. API. stablehloComparisonDirectionAttrGet (
474
- MLIR. IR. context (), $ hlocomp
475
- ),
476
- ),
477
- 1 ,
478
- ),
479
- size (lhs),
480
- )
500
+ elem_apply (fn, promote_to (rhs, lhs), rhs)
501
+ end
502
+
503
+ function $jlop (@nospecialize (lhs:: TracedRArray{T,N} ), @nospecialize (rhs)
504
+ ) where {T, N}
505
+ $ jlop (lhs, promote_to (lhs, rhs))
506
+ end
507
+
508
+ function $jlop (@nospecialize (lhs), @nospecialize (rhs:: TracedRArray{T,N} )
509
+ ) where {T, N}
510
+ $ jlop (promote_to (rhs, lhs), rhs)
511
+ end
512
+ end
513
+
514
+ if merge != nothing
515
+ @eval begin
516
+ function $jlop (
517
+ @nospecialize (lhs:: TracedRArray{T,N} ),
518
+ @nospecialize (rhs:: TracedRArray{T,N} )
519
+ ) where {T,N}
520
+ elems = elem_apply ($ jlop, lhs, rhs)
521
+ if N == 0
522
+ elems
523
+ else
524
+ $ merge (elems)
525
+ end
526
+ end
527
+ end
528
+ else
529
+ @eval begin
530
+ function $jlop (
531
+ @nospecialize (lhs:: TracedRArray{T,0} ),
532
+ @nospecialize (rhs:: TracedRArray{T,0} )
533
+ ) where {T}
534
+ elem_apply ($ jlop, lhs, rhs)
535
+ end
481
536
end
482
537
end
483
538
end
0 commit comments