Skip to content

Commit 1f3e177

Browse files
refactor: improve precompile-friendliness of mul_worker
1 parent 2f9e372 commit 1f3e177

File tree

2 files changed

+17
-9
lines changed

2 files changed

+17
-9
lines changed

src/methods.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,14 @@ const basic_monadic = [-, +]
2929
const basic_diadic = [+, -, *, /, //, \, ^]
3030
#######################################################
3131

32+
@inline function safe_eltype(T::TypeT)
33+
if T <: AbstractArray
34+
T.parameters[1]::TypeT
35+
else
36+
T
37+
end
38+
end
39+
3240
@inline function promote_type_fast_path(T::TypeT, S::TypeT)
3341
if T === S
3442
return T

src/types.jl

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3294,7 +3294,7 @@ function (mwb::MulWorkerBuffer{T})(terms) where {T}
32943294
# so we take the eltype, since `scalar * scalar` and `scalar * array`
32953295
# both give the correct result regardless of whether the first element
32963296
# is a scalar or array.
3297-
type::TypeT = eltype(symtype(Const{T}(first(terms))))
3297+
type::TypeT = safe_eltype(symtype(Const{T}(first(terms))))
32983298

32993299
for term in terms
33003300
term = unwrap(term)
@@ -3330,8 +3330,8 @@ function (mwb::MulWorkerBuffer{T})(terms) where {T}
33303330
filter!(kvp -> !iszero(kvp[2]), num_dict)
33313331
filter!(kvp -> !iszero(kvp[2]), den_dict)
33323332

3333-
ntrivialcoeff = isone(num_coeff[])::Bool
3334-
dtrivialcoeff = isone(den_coeff[])::Bool
3333+
ntrivialcoeff = _isone(num_coeff[])::Bool
3334+
dtrivialcoeff = _isone(den_coeff[])::Bool
33353335
ntrivialdict = isempty(num_dict)
33363336
dtrivialdict = isempty(den_dict)
33373337
ntrivial = ntrivialcoeff && ntrivialdict
@@ -3342,7 +3342,7 @@ function (mwb::MulWorkerBuffer{T})(terms) where {T}
33423342
elseif ntrivialdict
33433343
num = Const{T}(num_coeff[])
33443344
else
3345-
num = Mul{T}(num_coeff[], num_dict; type = eltype(type)::TypeT)
3345+
num = Mul{T}(num_coeff[], num_dict; type = safe_eltype(type)::TypeT)
33463346
@match num begin
33473347
BSImpl.AddMul(; dict) && if dict === num_dict end => begin
33483348
mwb.num_dict = ACDict{T}()
@@ -3355,7 +3355,7 @@ function (mwb::MulWorkerBuffer{T})(terms) where {T}
33553355
elseif dtrivialdict
33563356
den = Const{T}(den_coeff[])
33573357
else
3358-
den = Mul{T}(den_coeff[], den_dict; type = eltype(type)::TypeT)
3358+
den = Mul{T}(den_coeff[], den_dict; type = safe_eltype(type)::TypeT)
33593359
@match den begin
33603360
BSImpl.AddMul(; dict) && if dict === den_dict end => begin
33613361
mwb.den_dict = ACDict{T}()
@@ -3369,7 +3369,7 @@ function (mwb::MulWorkerBuffer{T})(terms) where {T}
33693369
elseif dtrivial
33703370
result = num
33713371
else
3372-
result = Div{T}(num, den, false; type = eltype(type)::TypeT)
3372+
result = Div{T}(num, den, false; type = safe_eltype(type)::TypeT)
33733373
end
33743374

33753375
isempty(arrterms) && return result
@@ -3386,7 +3386,7 @@ function (mwb::MulWorkerBuffer{T})(terms) where {T}
33863386
@match acc_arrterm begin
33873387
BSImpl.Term(; f, args) && if f === (^) && isconst(args[2]) end => begin
33883388
acc_arrterm = args[1]
3389-
acc_pow = unwrap_const(args[2])
3389+
acc_pow = unwrap_const(args[2])::Number
33903390
end
33913391
_ => nothing
33923392
end
@@ -3403,13 +3403,13 @@ function (mwb::MulWorkerBuffer{T})(terms) where {T}
34033403
acc_pow += 1
34043404
continue
34053405
end
3406-
push!(new_arrterms, isone(acc_pow) ? acc_arrterm : (acc_arrterm ^ acc_pow))
3406+
push!(new_arrterms, _isone(acc_pow) ? acc_arrterm : (acc_arrterm ^ acc_pow))
34073407
acc_arrterm = cur_arrterm
34083408
acc_pow = 1
34093409
end
34103410
end
34113411
end
3412-
push!(new_arrterms, isone(acc_pow) ? acc_arrterm : (acc_arrterm ^ acc_pow))
3412+
push!(new_arrterms, _isone(acc_pow) ? acc_arrterm : (acc_arrterm ^ acc_pow))
34133413
if length(new_arrterms) == 1
34143414
return new_arrterms[1]
34153415
end

0 commit comments

Comments
 (0)