|
37 | 37 |
|
38 | 38 | # Custom version of `AbstractTrees.printnode` to |
39 | 39 | # avoid type piracy when overloading on `AbstractNamedDimsArray`. |
40 | | -printnode(io::IO, x) = AbstractTrees.printnode(io, x) |
41 | | -function printnode(io::IO, a::AbstractNamedDimsArray) |
| 40 | +printnode_nameddims(io::IO, x) = AbstractTrees.printnode(io, x) |
| 41 | +function printnode_nameddims(io::IO, a::AbstractNamedDimsArray) |
42 | 42 | show(io, collect(dimnames(a))) |
43 | 43 | return nothing |
44 | 44 | end |
@@ -149,6 +149,37 @@ function map_arguments_lazy(f, a) |
149 | 149 | return error("Variant not supported.") |
150 | 150 | end |
151 | 151 | end |
| 152 | +function substitute_lazy(a, substitutions::AbstractDict) |
| 153 | + haskey(substitutions, a) && return substitutions[a] |
| 154 | + !iscall(a) && return a |
| 155 | + return map_arguments(arg -> substitute(arg, substitutions), a) |
| 156 | +end |
| 157 | +function substitute_lazy(a, substitutions) |
| 158 | + return substitute(a, Dict(substitutions)) |
| 159 | +end |
| 160 | +function printnode_lazy(io, a) |
| 161 | + # Use `printnode_nameddims` to avoid type piracy, |
| 162 | + # since it overloads on `AbstractNamedDimsArray`. |
| 163 | + return printnode_nameddims(io, unwrap(a)) |
| 164 | +end |
| 165 | +function show_lazy(io::IO, a) |
| 166 | + if !iscall(a) |
| 167 | + return show(io, unwrap(a)) |
| 168 | + else |
| 169 | + return AbstractTrees.printnode(io, a) |
| 170 | + end |
| 171 | +end |
| 172 | +function show_lazy(io::IO, mime::MIME"text/plain", a) |
| 173 | + summary(io, a) |
| 174 | + println(io, ":") |
| 175 | + if !iscall(a) |
| 176 | + show(io, mime, unwrap(a)) |
| 177 | + return nothing |
| 178 | + else |
| 179 | + show(io, a) |
| 180 | + return nothing |
| 181 | + end |
| 182 | +end |
152 | 183 |
|
153 | 184 | # Lazy broadcasting. |
154 | 185 | struct LazyNamedDimsArrayStyle <: AbstractNamedDimsArrayStyle{Any} end |
@@ -185,7 +216,7 @@ TermInterface.head(m::Applied) = operation(m) |
185 | 216 | TermInterface.iscall(m::Applied) = true |
186 | 217 | TermInterface.isexpr(m::Applied) = iscall(m) |
187 | 218 | function Base.show(io::IO, m::Applied) |
188 | | - args = map(arg -> sprint(printnode, arg), arguments(m)) |
| 219 | + args = map(arg -> sprint(AbstractTrees.printnode, arg), arguments(m)) |
189 | 220 | print(io, "(", join(args, " $(operation(m)) "), ")") |
190 | 221 | return nothing |
191 | 222 | end |
@@ -301,37 +332,20 @@ end |
301 | 332 | function map_arguments(f, a::LazyNamedDimsArray) |
302 | 333 | return map_arguments_lazy(f, a) |
303 | 334 | end |
304 | | - |
305 | | -function substitute(a::LazyNamedDimsArray, substitutions::AbstractDict) |
306 | | - haskey(substitutions, a) && return substitutions[a] |
307 | | - !iscall(a) && return a |
308 | | - return map_arguments(arg -> substitute(arg, substitutions), a) |
309 | | -end |
310 | 335 | function substitute(a::LazyNamedDimsArray, substitutions) |
311 | | - return substitute(a, Dict(substitutions)) |
312 | | -end |
313 | | - |
314 | | -function printnode(io::IO, a::LazyNamedDimsArray) |
315 | | - return printnode(io, unwrap(a)) |
| 336 | + return substitute_lazy(a, substitutions) |
316 | 337 | end |
317 | 338 | function AbstractTrees.printnode(io::IO, a::LazyNamedDimsArray) |
318 | | - return printnode(io, a) |
| 339 | + return printnode_lazy(io, a) |
| 340 | +end |
| 341 | +function printnode_nameddims(io::IO, a::LazyNamedDimsArray) |
| 342 | + return printnode_lazy(io, a) |
319 | 343 | end |
320 | 344 | function Base.show(io::IO, a::LazyNamedDimsArray) |
321 | | - if !iscall(a) |
322 | | - return show(io, unwrap(a)) |
323 | | - else |
324 | | - return printnode(io, a) |
325 | | - end |
| 345 | + return show_lazy(io, a) |
326 | 346 | end |
327 | 347 | function Base.show(io::IO, mime::MIME"text/plain", a::LazyNamedDimsArray) |
328 | | - if !iscall(a) |
329 | | - @invoke show(io, mime, a::AbstractNamedDimsArray) |
330 | | - return nothing |
331 | | - else |
332 | | - show(io, a) |
333 | | - return nothing |
334 | | - end |
| 348 | + return show_lazy(io, mime, a) |
335 | 349 | end |
336 | 350 |
|
337 | 351 | function Base.:*(a::LazyNamedDimsArray) |
@@ -427,13 +441,16 @@ const SymbolicNamedDimsArray{T, N, Parent <: SymbolicArray{T, N}, DimNames} = |
427 | 441 | function symnameddims(name) |
428 | 442 | return lazy(NamedDimsArray(SymbolicArray(name), ())) |
429 | 443 | end |
430 | | -function printnode(io::IO, a::SymbolicNamedDimsArray) |
| 444 | +function AbstractTrees.printnode(io::IO, a::SymbolicNamedDimsArray) |
431 | 445 | print(io, symname(dename(a))) |
432 | 446 | if ndims(a) > 0 |
433 | 447 | print(io, "[", join(dimnames(a), ","), "]") |
434 | 448 | end |
435 | 449 | return nothing |
436 | 450 | end |
| 451 | +function printnode_nameddims(io::IO, a::SymbolicNamedDimsArray) |
| 452 | + return AbstractTrees.printnode(io, a) |
| 453 | +end |
437 | 454 | function Base.:(==)(a::SymbolicNamedDimsArray, b::SymbolicNamedDimsArray) |
438 | 455 | return issetequal(inds(a), inds(b)) && dename(a) == dename(b) |
439 | 456 | end |
|
0 commit comments