@@ -242,153 +242,57 @@ function promote_to(::TracedRArray{T,N}, rhs) where {T,N}
242
242
return promote_to (TracedRArray{T,N}, rhs)
243
243
end
244
244
245
- for (jlop, hloop, RT) in (
246
- (:(Base. min), :minimum , :T ),
247
- (:(Base. max), :maximum , :T ),
248
- (:(Base.:+ ), :add , :T ),
249
- (:(Base.:- ), :subtract , :T ),
245
+ for (jlop, hloop) in (
246
+ (:(Base. min), :minimum ),
247
+ (:(Base. max), :maximum ),
248
+ (:(Base.:+ ), :add ),
249
+ (:(Base.:- ), :subtract ),
250
+ (:(Base.:* ), :multiply ),
251
+ (:(Base.:/ ), :divide ),
252
+ (:(Base.:^ ), :power ),
250
253
)
251
254
@eval begin
252
- function $jlop (
253
- @nospecialize (lhs:: TracedRArray{T,N} ), @nospecialize (rhs:: TracedRArray{T2,N} )
254
- ) where {T,T2,N}
255
- commonTy = TracedRArray{Base. promote_type (T, T2),N}
256
- lhs = promote_to (commonTy, lhs)
257
- rhs = promote_to (commonTy, rhs)
258
- return commonTy (
255
+ function $ (jlop)(
256
+ @nospecialize (lhs:: TracedRArray{T,0} ), @nospecialize (rhs:: TracedRArray{T,0} )
257
+ ) where {T}
258
+ return TracedRArray {T,0} (
259
259
(),
260
260
MLIR. IR. result (
261
- MLIR. Dialects. stablehlo.$ hloop (lhs. mlir_data, rhs. mlir_data), 1
261
+ MLIR. Dialects. stablehlo.$ ( hloop) (lhs. mlir_data, rhs. mlir_data), 1
262
262
),
263
- size (lhs ),
263
+ ( ),
264
264
)
265
265
end
266
266
267
- function $jlop (
268
- @nospecialize (lhs:: TracedRArray{T,N} ), @nospecialize (rhs:: TracedRArray{T,N} )
269
- ) where {T,N}
270
- return TracedRArray {$RT,N} (
271
- (),
272
- MLIR. IR. result (
273
- MLIR. Dialects. stablehlo.$ hloop (lhs. mlir_data, rhs. mlir_data), 1
274
- ),
275
- size (lhs),
276
- )
267
+ function $ (jlop)(
268
+ @nospecialize (lhs:: TracedRArray{T1,0} ), @nospecialize (rhs:: TracedRArray{T2,0} )
269
+ ) where {T1,T2}
270
+ commonTy = promote_type (T1, T2)
271
+ lhs = promote_to (commonTy, lhs)
272
+ rhs = promote_to (commonTy, rhs)
273
+ return $ (jlop)(lhs, rhs)
277
274
end
278
275
end
279
276
280
- for otherType in (Number, Any) #= TracedRArray{S,0} where {S} =#
277
+ for otherType in (Number, Any)
281
278
@eval begin
282
- function $jlop (
283
- @nospecialize (lhs:: TracedRArray{T,N } ), @nospecialize (rhs:: $otherType )
284
- ) where {T,N }
279
+ function $ ( jlop) (
280
+ @nospecialize (lhs:: TracedRArray{T,0 } ), @nospecialize (rhs:: $ ( otherType) )
281
+ ) where {T}
285
282
rhs = promote_to (lhs, rhs)
286
- return TracedRArray {$RT,N} (
287
- (),
288
- MLIR. IR. result (
289
- MLIR. Dialects. stablehlo.$ hloop (lhs. mlir_data, rhs. mlir_data), 1
290
- ),
291
- size (lhs),
292
- )
283
+ return $ (jlop)(lhs, rhs)
293
284
end
294
285
295
- function $jlop (
296
- @nospecialize (lhs:: $otherType ), @nospecialize (rhs:: TracedRArray{T,N } )
297
- ) where {T,N }
286
+ function $ ( jlop) (
287
+ @nospecialize (lhs:: $ ( otherType)) , @nospecialize (rhs:: TracedRArray{T,0 } )
288
+ ) where {T}
298
289
lhs = promote_to (rhs, lhs)
299
- return TracedRArray {$RT,N} (
300
- (),
301
- MLIR. IR. result (
302
- MLIR. Dialects. stablehlo.$ hloop (lhs. mlir_data, rhs. mlir_data), 1
303
- ),
304
- size (lhs),
305
- )
290
+ return $ (jlop)(lhs, rhs)
306
291
end
307
292
end
308
293
end
309
294
end
310
295
311
- for (jlop, hloop, RT) in
312
- ((:(Base.:* ), :multiply , :T ), (:(Base.:/ ), :divide , :T ), (:(Base.:^ ), :power , :T ))
313
- @eval begin
314
- function $jlop (
315
- @nospecialize (lhs:: TracedRArray{T,0} ), @nospecialize (rhs:: TracedRArray{T2,0} )
316
- ) where {T,T2}
317
- commonTy = TracedRArray{Base. promote_type (T, T2),0 }
318
- lhs = promote_to (commonTy, lhs)
319
- rhs = promote_to (commonTy, rhs)
320
- return commonTy (
321
- (),
322
- MLIR. IR. result (
323
- MLIR. Dialects. stablehlo.$ hloop (lhs. mlir_data, rhs. mlir_data), 1
324
- ),
325
- (),
326
- )
327
- end
328
-
329
- function $jlop (
330
- @nospecialize (lhs:: TracedRArray{T,0} ), @nospecialize (rhs:: TracedRArray{T,0} )
331
- ) where {T}
332
- return TracedRArray {$RT,0} (
333
- (),
334
- MLIR. IR. result (
335
- MLIR. Dialects. stablehlo.$ hloop (lhs. mlir_data, rhs. mlir_data), 1
336
- ),
337
- (),
338
- )
339
- end
340
-
341
- function $jlop (@nospecialize (lhs:: TracedRArray{T,0} ), @nospecialize (rhs)) where {T}
342
- rhs = promote_to (lhs, rhs)
343
- return TracedRArray {$RT,0} (
344
- (),
345
- MLIR. IR. result (
346
- MLIR. Dialects. stablehlo.$ hloop (lhs. mlir_data, rhs. mlir_data), 1
347
- ),
348
- (),
349
- )
350
- end
351
-
352
- function $jlop (@nospecialize (lhs), @nospecialize (rhs:: TracedRArray{T,0} )) where {T}
353
- lhs = promote_to (rhs, lhs)
354
- return TracedRArray {$RT,0} (
355
- (),
356
- MLIR. IR. result (
357
- MLIR. Dialects. stablehlo.$ hloop (lhs. mlir_data, rhs. mlir_data), 1
358
- ),
359
- (),
360
- )
361
- end
362
-
363
- # Base defines ::AbstractArray / ::Number, so we need this to avoid ambiguity
364
- function $jlop (
365
- @nospecialize (lhs:: TracedRArray{T,0} ), @nospecialize (rhs:: Number )
366
- ) where {T}
367
- rhs = promote_to (lhs, rhs)
368
- return TracedRArray {$RT,0} (
369
- (),
370
- MLIR. IR. result (
371
- MLIR. Dialects. stablehlo.$ hloop (lhs. mlir_data, rhs. mlir_data), 1
372
- ),
373
- (),
374
- )
375
- end
376
-
377
- function $jlop (
378
- @nospecialize (lhs:: Number ), @nospecialize (rhs:: TracedRArray{T,0} )
379
- ) where {T}
380
- lhs = promote_to (rhs, lhs)
381
- return TracedRArray {$RT,0} (
382
- (),
383
- MLIR. IR. result (
384
- MLIR. Dialects. stablehlo.$ hloop (lhs. mlir_data, rhs. mlir_data), 1
385
- ),
386
- (),
387
- )
388
- end
389
- end
390
- end
391
-
392
296
function Base. ifelse (
393
297
@nospecialize (pred:: TracedRArray{Bool,0} ),
394
298
@nospecialize (x:: TracedRArray{T1,0} ),
@@ -424,8 +328,8 @@ for (jlop, hloop) in (
424
328
(:(Base. sqrt), :sqrt ),
425
329
)
426
330
@eval begin
427
- function $jlop (@nospecialize (lhs:: TracedRArray{T,N } )) where {T,N }
428
- return TracedRArray {T,N } (
331
+ function $jlop (@nospecialize (lhs:: TracedRArray{T,0 } )) where {T}
332
+ return TracedRArray {T,0 } (
429
333
(),
430
334
MLIR. IR. result (MLIR. Dialects. stablehlo.$ hloop (lhs. mlir_data), 1 ),
431
335
size (lhs),
0 commit comments