Skip to content

Commit c9cc50b

Browse files
committed
refactor: remove old non-generic broadcasting code
1 parent 57d69a3 commit c9cc50b

File tree

1 file changed

+32
-128
lines changed

1 file changed

+32
-128
lines changed

src/TracedRArray.jl

Lines changed: 32 additions & 128 deletions
Original file line numberDiff line numberDiff line change
@@ -242,153 +242,57 @@ function promote_to(::TracedRArray{T,N}, rhs) where {T,N}
242242
return promote_to(TracedRArray{T,N}, rhs)
243243
end
244244

245-
for (jlop, hloop, RT) in (
246-
(:(Base.min), :minimum, :T),
247-
(:(Base.max), :maximum, :T),
248-
(:(Base.:+), :add, :T),
249-
(:(Base.:-), :subtract, :T),
245+
for (jlop, hloop) in (
246+
(:(Base.min), :minimum),
247+
(:(Base.max), :maximum),
248+
(:(Base.:+), :add),
249+
(:(Base.:-), :subtract),
250+
(:(Base.:*), :multiply),
251+
(:(Base.:/), :divide),
252+
(:(Base.:^), :power),
250253
)
251254
@eval begin
252-
function $jlop(
253-
@nospecialize(lhs::TracedRArray{T,N}), @nospecialize(rhs::TracedRArray{T2,N})
254-
) where {T,T2,N}
255-
commonTy = TracedRArray{Base.promote_type(T, T2),N}
256-
lhs = promote_to(commonTy, lhs)
257-
rhs = promote_to(commonTy, rhs)
258-
return commonTy(
255+
function $(jlop)(
256+
@nospecialize(lhs::TracedRArray{T,0}), @nospecialize(rhs::TracedRArray{T,0})
257+
) where {T}
258+
return TracedRArray{T,0}(
259259
(),
260260
MLIR.IR.result(
261-
MLIR.Dialects.stablehlo.$hloop(lhs.mlir_data, rhs.mlir_data), 1
261+
MLIR.Dialects.stablehlo.$(hloop)(lhs.mlir_data, rhs.mlir_data), 1
262262
),
263-
size(lhs),
263+
(),
264264
)
265265
end
266266

267-
function $jlop(
268-
@nospecialize(lhs::TracedRArray{T,N}), @nospecialize(rhs::TracedRArray{T,N})
269-
) where {T,N}
270-
return TracedRArray{$RT,N}(
271-
(),
272-
MLIR.IR.result(
273-
MLIR.Dialects.stablehlo.$hloop(lhs.mlir_data, rhs.mlir_data), 1
274-
),
275-
size(lhs),
276-
)
267+
function $(jlop)(
268+
@nospecialize(lhs::TracedRArray{T1,0}), @nospecialize(rhs::TracedRArray{T2,0})
269+
) where {T1,T2}
270+
commonTy = promote_type(T1, T2)
271+
lhs = promote_to(commonTy, lhs)
272+
rhs = promote_to(commonTy, rhs)
273+
return $(jlop)(lhs, rhs)
277274
end
278275
end
279276

280-
for otherType in (Number, Any) #=TracedRArray{S,0} where {S}=#
277+
for otherType in (Number, Any)
281278
@eval begin
282-
function $jlop(
283-
@nospecialize(lhs::TracedRArray{T,N}), @nospecialize(rhs::$otherType)
284-
) where {T,N}
279+
function $(jlop)(
280+
@nospecialize(lhs::TracedRArray{T,0}), @nospecialize(rhs::$(otherType))
281+
) where {T}
285282
rhs = promote_to(lhs, rhs)
286-
return TracedRArray{$RT,N}(
287-
(),
288-
MLIR.IR.result(
289-
MLIR.Dialects.stablehlo.$hloop(lhs.mlir_data, rhs.mlir_data), 1
290-
),
291-
size(lhs),
292-
)
283+
return $(jlop)(lhs, rhs)
293284
end
294285

295-
function $jlop(
296-
@nospecialize(lhs::$otherType), @nospecialize(rhs::TracedRArray{T,N})
297-
) where {T,N}
286+
function $(jlop)(
287+
@nospecialize(lhs::$(otherType)), @nospecialize(rhs::TracedRArray{T,0})
288+
) where {T}
298289
lhs = promote_to(rhs, lhs)
299-
return TracedRArray{$RT,N}(
300-
(),
301-
MLIR.IR.result(
302-
MLIR.Dialects.stablehlo.$hloop(lhs.mlir_data, rhs.mlir_data), 1
303-
),
304-
size(lhs),
305-
)
290+
return $(jlop)(lhs, rhs)
306291
end
307292
end
308293
end
309294
end
310295

311-
for (jlop, hloop, RT) in
312-
((:(Base.:*), :multiply, :T), (:(Base.:/), :divide, :T), (:(Base.:^), :power, :T))
313-
@eval begin
314-
function $jlop(
315-
@nospecialize(lhs::TracedRArray{T,0}), @nospecialize(rhs::TracedRArray{T2,0})
316-
) where {T,T2}
317-
commonTy = TracedRArray{Base.promote_type(T, T2),0}
318-
lhs = promote_to(commonTy, lhs)
319-
rhs = promote_to(commonTy, rhs)
320-
return commonTy(
321-
(),
322-
MLIR.IR.result(
323-
MLIR.Dialects.stablehlo.$hloop(lhs.mlir_data, rhs.mlir_data), 1
324-
),
325-
(),
326-
)
327-
end
328-
329-
function $jlop(
330-
@nospecialize(lhs::TracedRArray{T,0}), @nospecialize(rhs::TracedRArray{T,0})
331-
) where {T}
332-
return TracedRArray{$RT,0}(
333-
(),
334-
MLIR.IR.result(
335-
MLIR.Dialects.stablehlo.$hloop(lhs.mlir_data, rhs.mlir_data), 1
336-
),
337-
(),
338-
)
339-
end
340-
341-
function $jlop(@nospecialize(lhs::TracedRArray{T,0}), @nospecialize(rhs)) where {T}
342-
rhs = promote_to(lhs, rhs)
343-
return TracedRArray{$RT,0}(
344-
(),
345-
MLIR.IR.result(
346-
MLIR.Dialects.stablehlo.$hloop(lhs.mlir_data, rhs.mlir_data), 1
347-
),
348-
(),
349-
)
350-
end
351-
352-
function $jlop(@nospecialize(lhs), @nospecialize(rhs::TracedRArray{T,0})) where {T}
353-
lhs = promote_to(rhs, lhs)
354-
return TracedRArray{$RT,0}(
355-
(),
356-
MLIR.IR.result(
357-
MLIR.Dialects.stablehlo.$hloop(lhs.mlir_data, rhs.mlir_data), 1
358-
),
359-
(),
360-
)
361-
end
362-
363-
# Base defines ::AbstractArray / ::Number, so we need this to avoid ambiguity
364-
function $jlop(
365-
@nospecialize(lhs::TracedRArray{T,0}), @nospecialize(rhs::Number)
366-
) where {T}
367-
rhs = promote_to(lhs, rhs)
368-
return TracedRArray{$RT,0}(
369-
(),
370-
MLIR.IR.result(
371-
MLIR.Dialects.stablehlo.$hloop(lhs.mlir_data, rhs.mlir_data), 1
372-
),
373-
(),
374-
)
375-
end
376-
377-
function $jlop(
378-
@nospecialize(lhs::Number), @nospecialize(rhs::TracedRArray{T,0})
379-
) where {T}
380-
lhs = promote_to(rhs, lhs)
381-
return TracedRArray{$RT,0}(
382-
(),
383-
MLIR.IR.result(
384-
MLIR.Dialects.stablehlo.$hloop(lhs.mlir_data, rhs.mlir_data), 1
385-
),
386-
(),
387-
)
388-
end
389-
end
390-
end
391-
392296
function Base.ifelse(
393297
@nospecialize(pred::TracedRArray{Bool,0}),
394298
@nospecialize(x::TracedRArray{T1,0}),
@@ -424,8 +328,8 @@ for (jlop, hloop) in (
424328
(:(Base.sqrt), :sqrt),
425329
)
426330
@eval begin
427-
function $jlop(@nospecialize(lhs::TracedRArray{T,N})) where {T,N}
428-
return TracedRArray{T,N}(
331+
function $jlop(@nospecialize(lhs::TracedRArray{T,0})) where {T}
332+
return TracedRArray{T,0}(
429333
(),
430334
MLIR.IR.result(MLIR.Dialects.stablehlo.$hloop(lhs.mlir_data), 1),
431335
size(lhs),

0 commit comments

Comments
 (0)