From 5ce43dd208733ce7f35377e1c33efca5ec45def7 Mon Sep 17 00:00:00 2001 From: Willow Ahrens Date: Wed, 9 Apr 2025 14:47:54 -0400 Subject: [PATCH 1/2] add Shard Levels --- ext/SparseArraysExt.jl | 104 ++-- src/Finch.jl | 2 + src/architecture.jl | 61 +- src/environment.jl | 2 +- src/execute.jl | 40 +- src/tensors/levels/dense_levels.jl | 8 +- src/tensors/levels/dense_rle_levels.jl | 102 ++-- src/tensors/levels/element_levels.jl | 2 +- src/tensors/levels/separate_levels.jl | 2 +- src/tensors/levels/shard_levels.jl | 589 +++++++++++++++++++ src/tensors/levels/sparse_band_levels.jl | 52 +- src/tensors/levels/sparse_bytemap_levels.jl | 44 +- src/tensors/levels/sparse_coo_levels.jl | 54 +- src/tensors/levels/sparse_dict_levels.jl | 36 +- src/tensors/levels/sparse_interval_levels.jl | 18 +- src/tensors/levels/sparse_list_levels.jl | 60 +- src/tensors/levels/sparse_rle_levels.jl | 82 +-- src/tensors/levels/sparse_vbl_levels.jl | 44 +- 18 files changed, 966 insertions(+), 336 deletions(-) create mode 100644 src/tensors/levels/shard_levels.jl diff --git a/ext/SparseArraysExt.jl b/ext/SparseArraysExt.jl index 909ffe5b1..6992dcbf4 100644 --- a/ext/SparseArraysExt.jl +++ b/ext/SparseArraysExt.jl @@ -84,8 +84,8 @@ end ptr idx val - qos_fill - qos_stop + qos_used + qos_alloc prev_pos end @@ -127,14 +127,14 @@ function Finch.virtualize(ctx, ex, ::Type{<:SparseMatrixCSC{Tv,Ti}}, tag=:tns) w $val = $tag.nzval end, ) - qos_fill = freshen(ctx, tag, :_qos_fill) - qos_stop = freshen(ctx, tag, :_qos_stop) + qos_used = freshen(ctx, tag, :_qos_used) + qos_alloc = freshen(ctx, tag, :_qos_alloc) prev_pos = freshen(ctx, tag, :_prev_pos) shape = [ VirtualExtent(literal(1), value(m, Ti)), VirtualExtent(literal(1), value(n, Ti)) ] VirtualSparseMatrixCSC( - tag, Tv, Ti, shape, ptr, idx, val, qos_fill, qos_stop, prev_pos + tag, Tv, Ti, shape, ptr, idx, val, qos_used, qos_alloc, prev_pos ) end @@ -149,8 +149,8 @@ function distribute( distribute_buffer(ctx, arr.ptr, arch, style), 
distribute_buffer(ctx, arr.idx, arch, style), distribute_buffer(ctx, arr.val, arch, style), - arr.qos_fill, - arr.qos_stop, + arr.qos_used, + arr.qos_alloc, arr.prev_pos, ) end @@ -166,8 +166,8 @@ function Finch.declare!(ctx::AbstractCompiler, arr::VirtualSparseMatrixCSC, init push_preamble!( ctx, quote - $(arr.qos_fill) = $(Tp(0)) - $(arr.qos_stop) = $(Tp(0)) + $(arr.qos_used) = $(Tp(0)) + $(arr.qos_alloc) = $(Tp(0)) resize!($(arr.ptr), $pos_stop + 1) fill_range!($(arr.ptr), $(Tp(0)), 1, $pos_stop + 1) $(arr.ptr)[1] = $(Tp(1)) @@ -187,7 +187,7 @@ end function Finch.freeze!(ctx::AbstractCompiler, arr::VirtualSparseMatrixCSC) p = freshen(ctx, :p) pos_stop = ctx(getstop(virtual_size(ctx, arr)[2])) - qos_stop = freshen(ctx, :qos_stop) + qos_alloc = freshen(ctx, :qos_alloc) push_preamble!( ctx, quote @@ -195,9 +195,9 @@ function Finch.freeze!(ctx::AbstractCompiler, arr::VirtualSparseMatrixCSC) for $p in 1:($pos_stop) $(arr.ptr)[$p + 1] += $(arr.ptr)[$p] end - $qos_stop = $(arr.ptr)[$pos_stop + 1] - 1 - resize!($(arr.idx), $qos_stop) - resize!($(arr.val), $qos_stop) + $qos_alloc = $(arr.ptr)[$pos_stop + 1] - 1 + resize!($(arr.idx), $qos_alloc) + resize!($(arr.val), $qos_alloc) end, ) return arr @@ -206,19 +206,19 @@ end function Finch.thaw!(ctx::AbstractCompiler, arr::VirtualSparseMatrixCSC) p = freshen(ctx, :p) pos_stop = ctx(getstop(virtual_size(ctx, arr)[2])) - qos_stop = freshen(ctx, :qos_stop) + qos_alloc = freshen(ctx, :qos_alloc) push_preamble!( ctx, quote - $(arr.qos_fill) = $(arr.ptr)[$pos_stop + 1] - 1 - $(arr.qos_stop) = $(arr.qos_fill) - $qos_stop = $(arr.qos_fill) + $(arr.qos_used) = $(arr.ptr)[$pos_stop + 1] - 1 + $(arr.qos_alloc) = $(arr.qos_used) + $qos_alloc = $(arr.qos_used) $( if issafe(get_mode_flag(ctx)) quote $(arr.prev_pos) = Finch.scansearch( - $(arr.ptr), $(arr.qos_stop) + 1, 1, $pos_stop + $(arr.ptr), $(arr.qos_alloc) + 1, 1, $pos_stop ) - 1 end end @@ -350,12 +350,12 @@ function Finch.unfurl( j = tns.j Tp = arr.Ti qos = freshen(ctx, tag, 
:_qos) - qos_fill = arr.qos_fill - qos_stop = arr.qos_stop + qos_used = arr.qos_used + qos_alloc = arr.qos_alloc dirty = freshen(ctx, tag, :dirty) Thunk(; preamble = quote - $qos = $qos_fill + 1 + $qos = $qos_used + 1 $(if issafe(get_mode_flag(ctx)) quote $(arr.prev_pos) < $(ctx(j)) || throw(FinchProtocolError("SparseMatrixCSCs cannot be updated multiple times")) @@ -365,10 +365,10 @@ function Finch.unfurl( body = (ctx) -> Lookup(; body=(ctx, idx) -> Thunk(; preamble = quote - if $qos > $qos_stop - $qos_stop = max($qos_stop << 1, 1) - Finch.resize_if_smaller!($(arr.idx), $qos_stop) - Finch.resize_if_smaller!($(arr.val), $qos_stop) + if $qos > $qos_alloc + $qos_alloc = max($qos_alloc << 1, 1) + Finch.resize_if_smaller!($(arr.idx), $qos_alloc) + Finch.resize_if_smaller!($(arr.val), $qos_alloc) end $dirty = false end, @@ -387,8 +387,8 @@ function Finch.unfurl( ) ), epilogue = quote - $(arr.ptr)[$(ctx(j)) + 1] += $qos - $qos_fill - 1 - $qos_fill = $qos - 1 + $(arr.ptr)[$(ctx(j)) + 1] += $qos - $qos_used - 1 + $qos_used = $qos - 1 end, ) end @@ -429,8 +429,8 @@ end shape idx val - qos_fill - qos_stop + qos_used + qos_alloc end function Finch.virtual_size(ctx::AbstractCompiler, arr::VirtualSparseVector) @@ -460,9 +460,9 @@ function Finch.virtualize(ctx, ex, ::Type{<:SparseVector{Tv,Ti}}, tag=:tns) wher $val = $tag.nzval end, ) - qos_fill = freshen(ctx, tag, :_qos_fill) - qos_stop = freshen(ctx, tag, :_qos_stop) - VirtualSparseVector(tag, Tv, Ti, shape, idx, val, qos_fill, qos_stop) + qos_used = freshen(ctx, tag, :_qos_used) + qos_alloc = freshen(ctx, tag, :_qos_alloc) + VirtualSparseVector(tag, Tv, Ti, shape, idx, val, qos_used, qos_alloc) end function distribute( @@ -475,8 +475,8 @@ function distribute( arr.shape, distribute_buffer(ctx, arr.idx, arch, style), distribute_buffer(ctx, arr.val, arch, style), - arr.qos_fill, - arr.qos_stop, + arr.qos_used, + arr.qos_alloc, ) end @@ -490,8 +490,8 @@ function Finch.declare!(ctx::AbstractCompiler, arr::VirtualSparseVector, 
init) push_preamble!( ctx, quote - $(arr.qos_fill) = $(Tp(0)) - $(arr.qos_stop) = $(Tp(0)) + $(arr.qos_used) = $(Tp(0)) + $(arr.qos_alloc) = $(Tp(0)) end, ) return arr @@ -499,13 +499,13 @@ end function Finch.freeze!(ctx::AbstractCompiler, arr::VirtualSparseVector) p = freshen(ctx, :p) - qos_stop = freshen(ctx, :qos_stop) + qos_alloc = freshen(ctx, :qos_alloc) push_preamble!( ctx, quote - $qos_stop = $(ctx(arr.qos_fill)) - resize!($(arr.idx), $qos_stop) - resize!($(arr.val), $qos_stop) + $qos_alloc = $(ctx(arr.qos_used)) + resize!($(arr.idx), $qos_alloc) + resize!($(arr.val), $qos_alloc) end, ) return arr @@ -513,13 +513,13 @@ end function Finch.thaw!(ctx::AbstractCompiler, arr::VirtualSparseVector) p = freshen(ctx, :p) - qos_stop = freshen(ctx, :qos_stop) + qos_alloc = freshen(ctx, :qos_alloc) push_preamble!( ctx, quote - $(arr.qos_fill) = length($(arr.idx)) - $(arr.qos_stop) = $(arr.qos_fill) - $qos_stop = $(arr.qos_fill) + $(arr.qos_used) = length($(arr.idx)) + $(arr.qos_alloc) = $(arr.qos_used) + $qos_alloc = $(arr.qos_used) end, ) return arr @@ -593,23 +593,23 @@ function Finch.unfurl( tag = arr.tag Tp = arr.Ti qos = freshen(ctx, tag, :_qos) - qos_fill = arr.qos_fill - qos_stop = arr.qos_stop + qos_used = arr.qos_used + qos_alloc = arr.qos_alloc dirty = freshen(ctx, tag, :dirty) Unfurled(; arr=arr, body=Thunk(; preamble = quote - $qos = $qos_fill + 1 + $qos = $qos_used + 1 end, body = (ctx) -> Lookup(; body=(ctx, idx) -> Thunk(; preamble = quote - if $qos > $qos_stop - $qos_stop = max($qos_stop << 1, 1) - Finch.resize_if_smaller!($(arr.idx), $qos_stop) - Finch.resize_if_smaller!($(arr.val), $qos_stop) + if $qos > $qos_alloc + $qos_alloc = max($qos_alloc << 1, 1) + Finch.resize_if_smaller!($(arr.idx), $qos_alloc) + Finch.resize_if_smaller!($(arr.val), $qos_alloc) end $dirty = false end, @@ -623,7 +623,7 @@ function Finch.unfurl( ) ), epilogue = quote - $qos_fill = $qos - 1 + $qos_used = $qos - 1 end, ), ) diff --git a/src/Finch.jl b/src/Finch.jl index 
25df496df..5124eba93 100644 --- a/src/Finch.jl +++ b/src/Finch.jl @@ -42,6 +42,7 @@ export Dense, DenseLevel export Element, ElementLevel export AtomicElement, AtomicElementLevel export Separate, SeparateLevel +export Shard, ShardLevel export Mutex, MutexLevel export Pattern, PatternLevel export Scalar, SparseScalar, ShortCircuitScalar, SparseShortCircuitScalar @@ -142,6 +143,7 @@ include("tensors/levels/dense_rle_levels.jl") include("tensors/levels/element_levels.jl") include("tensors/levels/atomic_element_levels.jl") include("tensors/levels/separate_levels.jl") +include("tensors/levels/shard_levels.jl") include("tensors/levels/mutex_levels.jl") include("tensors/levels/pattern_levels.jl") include("tensors/masks.jl") diff --git a/src/architecture.jl b/src/architecture.jl index dab10d145..f58ce77e5 100644 --- a/src/architecture.jl +++ b/src/architecture.jl @@ -41,12 +41,14 @@ abstract type AbstractVirtualTask end Return the number of tasks on the device dev. """ function get_num_tasks end + """ get_task_num(task::AbstractTask) Return the task number of `task`. """ function get_task_num end + """ get_device(task::AbstractTask) @@ -61,6 +63,25 @@ Return the task which spawned `task`. """ function get_parent_task end +get_num_tasks(ctx::AbstractCompiler) = get_num_tasks(get_task(ctx)) +get_num_tasks(task::AbstractTask) = get_num_tasks(get_device(task)) +get_task_num(ctx::AbstractCompiler) = get_task_num(get_task(ctx)) +get_device(ctx::AbstractCompiler) = get_device(get_task(ctx)) +get_parent_task(ctx::AbstractCompiler) = get_parent_task(get_task(ctx)) + +function is_on_device(ctx::AbstractCompiler, dev) + res = false + task = get_task(ctx) + while task != nothing + if get_device(task) == dev + res = true + break + end + task = get_parent_task(task) + end + return res +end + """ aquire_lock!(dev::AbstractDevice, val) @@ -92,20 +113,35 @@ function make_lock end """ Serial() -A device that represents a serial CPU execution. +A Task that represents a serial CPU execution. 
""" -struct Serial <: AbstractTask end +struct Serial <: AbstractDevice end const serial = Serial() -get_device(::Serial) = CPU(1) -get_parent_task(::Serial) = nothing -get_task_num(::Serial) = 1 +get_num_tasks(::Serial) = 1 struct VirtualSerial <: AbstractVirtualTask end virtualize(ctx, ex, ::Type{Serial}) = VirtualSerial() lower(ctx::AbstractCompiler, task::VirtualSerial, ::DefaultStyle) = :(Serial()) FinchNotation.finch_leaf(device::VirtualSerial) = virtual(device) -get_device(::VirtualSerial) = VirtualCPU(nothing, 1) -get_parent_task(::VirtualSerial) = nothing -get_task_num(::VirtualSerial) = literal(1) +get_num_tasks(::VirtualSerial) = literal(1) +Base.:(==)(::Serial, ::Serial) = true +Base.:(==)(::VirtualSerial, ::VirtualSerial) = true + +""" + SerialTask() + +A Task that represents a serial CPU execution. +""" +struct SerialTask <: AbstractDevice end +get_device(::SerialTask) = Serial() +get_parent_task(::SerialTask) = nothing +get_task_num(::SerialTask) = 1 +struct VirtualSerialTask <: AbstractVirtualTask end +virtualize(ctx, ex, ::Type{SerialTask}) = VirtualSerialTask() +lower(ctx::AbstractCompiler, task::VirtualSerialTask, ::DefaultStyle) = :(SerialTask()) +FinchNotation.finch_leaf(device::VirtualSerialTask) = virtual(device) +get_device(::VirtualSerialTask) = VirtualSerial() +get_parent_task(::VirtualSerialTask) = nothing +get_task_num(::VirtualSerialTask) = literal(1) struct SerialMemory end struct VirtualSerialMemory end @@ -148,6 +184,8 @@ function lower(ctx::AbstractCompiler, device::VirtualCPU, ::DefaultStyle) something(device.ex, :(CPU($(ctx(device.n))))) end get_num_tasks(::VirtualCPU) = literal(1) +Base.:(==)(::CPU, ::CPU) = true +Base.:(==)(::VirtualCPU, ::VirtualCPU) = true #This is not strictly true. A better approach would name devices, and give them parents so that we can be sure to parallelize through the processor hierarchy. 
FinchNotation.finch_leaf(device::VirtualCPU) = virtual(device) @@ -212,7 +250,7 @@ function transfer(device::CPULocalMemory, arr::AbstractArray) CPULocalArray{A}(mem.device, [copy(arr) for _ in 1:(mem.device.n)]) end function transfer(task::CPUThread, arr::CPULocalArray) - if get_device(task) === arr.device + if get_device(task) == arr.device temp = arr.data[task.tid] return temp else @@ -223,6 +261,7 @@ function transfer(dst::AbstractArray, arr::AbstractArray) return arr end + """ transfer(device, arr) @@ -484,8 +523,8 @@ for T in [ end end -function virtual_parallel_region(f, ctx, ::Serial) - contain(f, ctx) +function virtual_parallel_region(f, ctx, ::VirtualSerial) + contain(f, ctx; task=VirtualSerialTask()) end function virtual_parallel_region(f, ctx, device::VirtualCPU) diff --git a/src/environment.jl b/src/environment.jl index a845629eb..ede09be8f 100644 --- a/src/environment.jl +++ b/src/environment.jl @@ -44,7 +44,7 @@ variable names in the generated code of the executing environment. namespace::Namespace = Namespace() preamble::Vector{Any} = [] epilogue::Vector{Any} = [] - task = VirtualSerial() + task = VirtualSerialTask() end """ diff --git a/src/execute.jl b/src/execute.jl index 863d6f804..f116e7a53 100644 --- a/src/execute.jl +++ b/src/execute.jl @@ -181,18 +181,18 @@ macro finch(opts_ex...) (opts, ex) = (opts_ex[1:(end - 1)], opts_ex[end]) prgm = FinchNotation.finch_parse_instance(ex) prgm = :( - $(FinchNotation.block_instance)( - $prgm, - $(FinchNotation.yieldbind_instance)( - $( - map( - FinchNotation.variable_instance, - FinchNotation.finch_parse_default_yieldbind(ex), - )... + $(FinchNotation.block_instance)( + $prgm, + $(FinchNotation.yieldbind_instance)( + $( + map( + FinchNotation.variable_instance, + FinchNotation.finch_parse_default_yieldbind(ex), + )... + ), ), - ), + ) ) -) res = esc(:res) thunk = quote res = $execute($prgm, ; $(map(esc, opts)...)) @@ -230,18 +230,18 @@ macro finch_code(opts_ex...) 
(opts, ex) = (opts_ex[1:(end - 1)], opts_ex[end]) prgm = FinchNotation.finch_parse_instance(ex) prgm = :( - $(FinchNotation.block_instance)( - $prgm, - $(FinchNotation.yieldbind_instance)( - $( - map( - FinchNotation.variable_instance, - FinchNotation.finch_parse_default_yieldbind(ex), - )... + $(FinchNotation.block_instance)( + $prgm, + $(FinchNotation.yieldbind_instance)( + $( + map( + FinchNotation.variable_instance, + FinchNotation.finch_parse_default_yieldbind(ex), + )... + ), ), - ), + ) ) -) return quote unquote_literals( dataflow( diff --git a/src/tensors/levels/dense_levels.jl b/src/tensors/levels/dense_levels.jl index f76ca504d..c89861ccf 100644 --- a/src/tensors/levels/dense_levels.jl +++ b/src/tensors/levels/dense_levels.jl @@ -194,15 +194,15 @@ end function assemble_level!(ctx, lvl::VirtualDenseLevel, pos_start, pos_stop) qos_start = call(+, call(*, call(-, pos_start, lvl.Ti(1)), lvl.shape), 1) - qos_stop = call(*, pos_stop, lvl.shape) - assemble_level!(ctx, lvl.lvl, qos_start, qos_stop) + qos_alloc = call(*, pos_stop, lvl.shape) + assemble_level!(ctx, lvl.lvl, qos_start, qos_alloc) end supports_reassembly(::VirtualDenseLevel) = true function reassemble_level!(ctx, lvl::VirtualDenseLevel, pos_start, pos_stop) qos_start = call(+, call(*, call(-, pos_start, lvl.Ti(1)), lvl.shape), 1) - qos_stop = call(*, pos_stop, lvl.shape) - reassemble_level!(ctx, lvl.lvl, qos_start, qos_stop) + qos_alloc = call(*, pos_stop, lvl.shape) + reassemble_level!(ctx, lvl.lvl, qos_start, qos_alloc) lvl end diff --git a/src/tensors/levels/dense_rle_levels.jl b/src/tensors/levels/dense_rle_levels.jl index 1364ac031..7fa29226a 100644 --- a/src/tensors/levels/dense_rle_levels.jl +++ b/src/tensors/levels/dense_rle_levels.jl @@ -206,8 +206,8 @@ mutable struct VirtualRunListLevel <: AbstractVirtualLevel lvl Ti shape - qos_fill - qos_stop + qos_used + qos_alloc ptr right buf @@ -240,12 +240,12 @@ function virtualize( # 1. prevpos is the last position written (initially 0) # 2. 
i_prev is the last index written (initially shape) # 3. for all p in 1:prevpos-1, ptr[p] is the number of runs in that position - # 4. qos_fill is the position of the last index written + # 4. qos_used is the position of the last index written tag = freshen(ctx, tag) stop = freshen(ctx, tag, :_stop) - qos_fill = freshen(ctx, tag, :_qos_fill) - qos_stop = freshen(ctx, tag, :_qos_stop) + qos_used = freshen(ctx, tag, :_qos_used) + qos_alloc = freshen(ctx, tag, :_qos_alloc) dirty = freshen(ctx, tag, :_dirty) ptr = freshen(ctx, tag, :_ptr) right = freshen(ctx, tag, :_right) @@ -266,7 +266,7 @@ function virtualize( lvl_2 = virtualize(ctx, :($tag.lvl), Lvl, tag) buf = virtualize(ctx, :($tag.buf), Lvl, tag) VirtualRunListLevel( - tag, lvl_2, Ti, shape, qos_fill, qos_stop, ptr, right, buf, prev_pos, i_prev, + tag, lvl_2, Ti, shape, qos_used, qos_alloc, ptr, right, buf, prev_pos, i_prev, merge, ) end @@ -292,8 +292,8 @@ function distribute_level( distribute_level(ctx, lvl.lvl, arch, diff, style), lvl.Ti, lvl.shape, - lvl.qos_fill, - lvl.qos_stop, + lvl.qos_used, + lvl.qos_alloc, distribute_buffer(ctx, lvl.ptr, arch, style), distribute_buffer(ctx, lvl.right, arch, style), distribute_level(ctx, lvl.buf, arch, diff, style), @@ -312,8 +312,8 @@ function redistribute(ctx::AbstractCompiler, lvl::VirtualRunListLevel, diff) redistribute(ctx, lvl.lvl, diff), lvl.Ti, lvl.shape, - lvl.qos_fill, - lvl.qos_stop, + lvl.qos_used, + lvl.qos_alloc, lvl.ptr, lvl.right, redistribute(ctx, lvl.buf, diff), @@ -349,8 +349,8 @@ function declare_level!(ctx::AbstractCompiler, lvl::VirtualRunListLevel, pos, in push_preamble!( ctx, quote - $(lvl.qos_fill) = $(Tp(0)) - $(lvl.qos_stop) = $(Tp(0)) + $(lvl.qos_used) = $(Tp(0)) + $(lvl.qos_alloc) = $(Tp(0)) $(lvl.i_prev) = $(Ti(1)) - $unit $(lvl.prev_pos) = $(Tp(1)) end, @@ -373,16 +373,16 @@ function freeze_level!(ctx::AbstractCompiler, lvl::VirtualRunListLevel, pos_stop (lvl.buf, lvl.lvl) = (lvl.lvl, lvl.buf) p = freshen(ctx, :p) pos_stop = 
ctx(cache!(ctx, :pos_stop, simplify(ctx, pos_stop))) - qos_stop = freshen(ctx, :qos_stop) + qos_alloc = freshen(ctx, :qos_alloc) push_preamble!(ctx, quote resize!($(lvl.ptr), $pos_stop + 1) for $p = 1:$pos_stop $(lvl.ptr)[$p + 1] += $(lvl.ptr)[$p] end - $qos_stop = $(lvl.ptr)[$pos_stop + 1] - 1 - resize!($(lvl.right), $qos_stop) + $qos_alloc = $(lvl.ptr)[$pos_stop + 1] - 1 + resize!($(lvl.right), $qos_alloc) end) - lvl.lvl = freeze_level!(ctx, lvl.lvl, value(qos_stop)) + lvl.lvl = freeze_level!(ctx, lvl.lvl, value(qos_alloc)) return lvl end =# @@ -393,30 +393,30 @@ function freeze_level!(ctx::AbstractCompiler, lvl::VirtualRunListLevel, pos_stop pos_stop = ctx(cache!(ctx, :pos_stop, simplify(ctx, pos_stop))) Ti = lvl.Ti pos_2 = freshen(ctx, tag, :_pos) - qos_stop = lvl.qos_stop - qos_fill = lvl.qos_fill + qos_alloc = lvl.qos_alloc + qos_used = lvl.qos_used qos = freshen(ctx, :qos) unit = ctx(get_smallest_measure(virtual_level_size(ctx, lvl)[end])) push_preamble!( ctx, quote - $qos = $(lvl.qos_fill) + $qos = $(lvl.qos_used) #if we did not write something to finish out the last run, we need to fill that in $qos += $(lvl.i_prev) < $(ctx(lvl.shape)) #and all the runs after that $qos += $(pos_stop) - $(lvl.prev_pos) - if $qos > $qos_stop - $qos_stop = $qos - Finch.resize_if_smaller!($(lvl.right), $qos_stop) + if $qos > $qos_alloc + $qos_alloc = $qos + Finch.resize_if_smaller!($(lvl.right), $qos_alloc) Finch.fill_range!( - $(lvl.right), $(ctx(lvl.shape)), $qos_fill + 1, $qos_stop + $(lvl.right), $(ctx(lvl.shape)), $qos_used + 1, $qos_alloc ) $(contain( ctx_2 -> assemble_level!( ctx_2, lvl.buf, - call(+, value(qos_fill, Tp), Tp(1)), - value(qos_stop, Tp), + call(+, value(qos_used, Tp), Tp(1)), + value(qos_alloc, Tp), ), ctx, )) @@ -425,11 +425,11 @@ function freeze_level!(ctx::AbstractCompiler, lvl::VirtualRunListLevel, pos_stop for $p in 1:($pos_stop) $(lvl.ptr)[$p + 1] += $(lvl.ptr)[$p] end - $qos_stop = $(lvl.ptr)[$pos_stop + 1] - 1 + $qos_alloc = $(lvl.ptr)[$pos_stop + 
1] - 1 end, ) if lvl.merge - lvl.buf = freeze_level!(ctx, lvl.buf, value(qos_stop)) + lvl.buf = freeze_level!(ctx, lvl.buf, value(qos_alloc)) lvl.lvl = declare_level!( ctx, lvl.lvl, literal(1), literal(virtual_level_fill_value(lvl.buf)) ) @@ -444,7 +444,7 @@ function freeze_level!(ctx::AbstractCompiler, lvl::VirtualRunListLevel, pos_stop quote $(contain( ctx_2 -> - assemble_level!(ctx_2, lvl.lvl, value(1, Tp), value(qos_stop, Tp)), + assemble_level!(ctx_2, lvl.lvl, value(1, Tp), value(qos_alloc, Tp)), ctx, )) $q = 1 @@ -543,10 +543,10 @@ function freeze_level!(ctx::AbstractCompiler, lvl::VirtualRunListLevel, pos_stop $(lvl.ptr)[$p + 1] = $q_2 end resize!($(lvl.right), $q_2 - 1) - $qos_stop = $q_2 - 1 + $qos_alloc = $q_2 - 1 end, ) - lvl.lvl = freeze_level!(ctx, lvl.lvl, value(qos_stop)) + lvl.lvl = freeze_level!(ctx, lvl.lvl, value(qos_alloc)) lvl.buf = declare_level!( ctx, lvl.buf, literal(1), literal(virtual_level_fill_value(lvl.buf)) ) @@ -556,11 +556,11 @@ function freeze_level!(ctx::AbstractCompiler, lvl::VirtualRunListLevel, pos_stop push_preamble!( ctx, quote - resize!($(lvl.right), $qos_stop) + resize!($(lvl.right), $qos_alloc) end, ) (lvl.buf, lvl.lvl) = (lvl.lvl, lvl.buf) - lvl.lvl = freeze_level!(ctx, lvl.lvl, value(qos_stop)) + lvl.lvl = freeze_level!(ctx, lvl.lvl, value(qos_alloc)) return lvl end end @@ -572,20 +572,20 @@ function thaw_level!(ctx::AbstractCompiler, lvl::VirtualRunListLevel, pos_stop) #= p = freshen(ctx, :p) pos_stop = ctx(cache!(ctx, :pos_stop, simplify(ctx, pos_stop))) - qos_stop = freshen(ctx, :qos_stop) + qos_alloc = freshen(ctx, :qos_alloc) unit = ctx(get_smallest_measure(virtual_level_size(ctx, lvl)[end])) push_preamble!(ctx, quote - $(lvl.qos_fill) = $(lvl.ptr)[$pos_stop + 1] - 1 - $(lvl.qos_stop) = $(lvl.qos_fill) - $(lvl.i_prev) = $(lvl.right)[$(lvl.qos_fill)] - $qos_stop = $(lvl.qos_fill) - $(lvl.prev_pos) = Finch.scansearch($(lvl.ptr), $(lvl.qos_stop) + 1, 1, $pos_stop) - 1 + $(lvl.qos_used) = $(lvl.ptr)[$pos_stop + 1] - 1 + 
$(lvl.qos_alloc) = $(lvl.qos_used) + $(lvl.i_prev) = $(lvl.right)[$(lvl.qos_used)] + $qos_alloc = $(lvl.qos_used) + $(lvl.prev_pos) = Finch.scansearch($(lvl.ptr), $(lvl.qos_alloc) + 1, 1, $pos_stop) - 1 for $p = $pos_stop:-1:1 $(lvl.ptr)[$p + 1] -= $(lvl.ptr)[$p] end end) (lvl.lvl, lvl.buf) = (lvl.buf, lvl.lvl) - lvl.buf = thaw_level!(ctx, lvl.buf, value(qos_stop)) + lvl.buf = thaw_level!(ctx, lvl.buf, value(qos_alloc)) return lvl =# end @@ -663,7 +663,7 @@ end # 1. prevpos is the last position written (initially 0) # 2. i_prev is the last index written (initially shape) # 3. for all p in 1:prevpos-1, ptr[p] is the number of runs in that position -# 4. qos_fill is the position of the last index written +# 4. qos_used is the position of the last index written function unfurl( ctx, @@ -677,8 +677,8 @@ function unfurl( Tp = postype(lvl) Ti = lvl.Ti qos = freshen(ctx, tag, :_qos) - qos_fill = lvl.qos_fill - qos_stop = lvl.qos_stop + qos_used = lvl.qos_used + qos_alloc = lvl.qos_alloc dirty = freshen(ctx, tag, :dirty) pos_2 = freshen(ctx, tag, :_pos) unit = ctx(get_smallest_measure(virtual_level_size(ctx, lvl)[end])) @@ -691,7 +691,7 @@ function unfurl( arr=fbr, body=Thunk(; preamble=quote - $qos = $qos_fill + 1 + $qos = $qos_used + 1 $( if issafe(get_mode_flag(ctx)) quote @@ -717,14 +717,14 @@ function unfurl( body=(ctx, ext) -> Thunk(; preamble = quote $qos_3 = $qos + ($(local_i_prev) < ($(ctx(getstart(ext))) - $unit)) - if $qos_3 > $qos_stop - $qos_2 = $qos_stop + 1 - while $qos_3 > $qos_stop - $qos_stop = max($qos_stop << 1, 1) + if $qos_3 > $qos_alloc + $qos_2 = $qos_alloc + 1 + while $qos_3 > $qos_alloc + $qos_alloc = max($qos_alloc << 1, 1) end - Finch.resize_if_smaller!($(lvl.right), $qos_stop) - Finch.fill_range!($(lvl.right), $(ctx(lvl.shape)), $qos_2, $qos_stop) - $(contain(ctx_2 -> assemble_level!(ctx_2, lvl.buf, value(qos_2, Tp), value(qos_stop, Tp)), ctx)) + Finch.resize_if_smaller!($(lvl.right), $qos_alloc) + Finch.fill_range!($(lvl.right), 
$(ctx(lvl.shape)), $qos_2, $qos_alloc) + $(contain(ctx_2 -> assemble_level!(ctx_2, lvl.buf, value(qos_2, Tp), value(qos_alloc, Tp)), ctx)) end $dirty = false end, @@ -746,7 +746,7 @@ function unfurl( $qos - $qos_set - ($(local_i_prev) == $(ctx(lvl.shape))) #the last run is accounted for already because ptr starts out at 1 $(lvl.prev_pos) = $(ctx(pos)) $(lvl.i_prev) = $(local_i_prev) - $qos_fill = $qos - 1 + $qos_used = $qos - 1 end end, ), diff --git a/src/tensors/levels/element_levels.jl b/src/tensors/levels/element_levels.jl index c429747d0..55ef387ae 100644 --- a/src/tensors/levels/element_levels.jl +++ b/src/tensors/levels/element_levels.jl @@ -43,7 +43,7 @@ end postype(::Type{<:ElementLevel{Vf,Tv,Tp}}) where {Vf,Tv,Tp} = Tp -function transfer(lvl::ElementLevel{Vf,Tv,Tp}, device, style) where {Vf,Tv,Tp} +function transfer(device, lvl::ElementLevel{Vf,Tv,Tp}) where {Vf,Tv,Tp} return ElementLevel{Vf,Tv,Tp}(transfer(device, lvl.val)) end diff --git a/src/tensors/levels/separate_levels.jl b/src/tensors/levels/separate_levels.jl index 68b341986..c0712862c 100644 --- a/src/tensors/levels/separate_levels.jl +++ b/src/tensors/levels/separate_levels.jl @@ -95,7 +95,7 @@ countstored_level(lvl::SeparateLevel, pos) = pos mutable struct VirtualSeparateLevel <: AbstractVirtualLevel tag - lvl # stand in for the sublevel for virutal resize, etc. + lvl # stand in for the sublevel for virtual resize, etc. 
val Tv Lvl diff --git a/src/tensors/levels/shard_levels.jl b/src/tensors/levels/shard_levels.jl new file mode 100644 index 000000000..6eedfae40 --- /dev/null +++ b/src/tensors/levels/shard_levels.jl @@ -0,0 +1,589 @@ +struct MultiChannelMemory{Device} <: AbstractDevice + device::Device + n::Int +end + +Base.:(==)(device::MultiChannelMemory, other::MultiChannelMemory) = + device.device == other.device + +get_num_tasks(device::MultiChannelMemory) = device.n +get_device(device::MultiChannelMemory) = device.device + +struct VirtualMultiChannelMemory <: AbstractVirtualDevice + device + n +end + +Base.:(==)(device::VirtualMultiChannelMemory, other::VirtualMultiChannelMemory) = + device.device == other.device + +get_num_tasks(device::VirtualMultiChannelMemory) = device.n +get_device(device::VirtualMultiChannelMemory) = device.device + +function virtualize(ctx, ex, ::Type{MultiChannelMemory{Device}}) where {Device} + device = virtualize(ctx, :($ex.device), Device) + n = freshen(ctx, :n) + push_preamble!(quote + $n = $ex.n + end) + VirtualMultiChannelMemory(device, n) +end + +function lower(ctx::AbstractCompiler, mem::VirtualMultiChannelMemory, ::DefaultStyle) + quote + MultiChannelMemory($(ctx(mem.device)), $(ctx(mem.n))) + end +end + +struct MemoryChannel{Device<:MultiChannelMemory, Parent} <: AbstractTask + t::Int + device::Device + Parent::Parent +end + +get_device(device::MemoryChannel) = device.device +get_parent_task(device::MemoryChannel) = device.parent +get_task_num(device::MemoryChannel) = device.t + +struct VirtualMemoryChannel <: AbstractVirtualTask + t + device + parent +end + +function virtualize(ctx, ex, ::Type{MemoryChannel{Device, Parent}}) where {Device, Parent} + device = virtualize(ctx, :($ex.device), Device) + parent = virtualize(ctx, :($ex.parent), Parent) + t = freshen(ctx, :t) + push_preamble!(quote + $t = $(ctx(ex.t)) + end) + VirtualMemoryChannel(device, t) +end + +function lower(ctx::AbstractCompiler, mem::VirtualMemoryChannel, ::DefaultStyle) + 
quote + MemoryChannel($(ctx(mem.t)), $(ctx(mem.device)), $(ctx(mem.parent))) + end +end + +struct MultiChannelBuffer{A} + device::MultiChannelMemory + data::Vector{A} +end + +Base.eltype(::Type{MultiChannelBuffer{A}}) where {A} = eltype(A) +Base.ndims(::Type{MultiChannelBuffer{A}}) where {A} = ndims(A) + +function transfer(device::MultiChannelMemory, arr::AbstractArray) + data = [transfer(device.device, copy(arr)) for _ in 1:(device.n)] + MultiChannelBuffer(device, data) +end +function transfer(device::MultiChannelMemory, arr::MultiChannelBuffer) + data = arr.data + if device.device != arr.device + data = map(buf -> transfer(device.device, buf), data) + end + if arr.device.n > device.n + MultiChannelBuffer(device, data) + else + MultiChannelBuffer(device, vcat(data, [transfer(device, []) for _ in 1:(device.n - arr.device.n)])) + end +end + +function transfer(task::MemoryChannel, arr::MultiChannelBuffer) + if task.device == arr.device + temp = arr.data[task.t] + return temp + else + return arr + end +end +function transfer(dst::MultiChannelBuffer, arr::MultiChannelBuffer) + return arr +end +function transfer(dst::AbstractDevice, arr::MultiChannelBuffer) + if dst == arr.device + return arr + else + data = map(buf -> transfer(device.device, buf), arr.data) + return MultiChannelBuffer(arr.device, data) + end +end + +""" + ShardLevel{Lvl}() + +Each subfiber of a Shard level is stored in a thread-specific tensor of type +`Lvl`, managed by MultiChannelMemory. 
+ +```jldoctest +julia> tensor_tree(Tensor(Dense(Shard(Element(0.0))), [1, 2, 3])) +3-Tensor +└─ Dense [1:3] + ├─ [1]: Shard -> + │ └─ 1.0 + ├─ [2]: Shard -> + │ └─ 2.0 + └─ [3]: Shard -> + └─ 3.0 +``` +""" +struct ShardLevel{Device,Lvl,Ptr,Task,Used,Alloc} <: AbstractLevel + device::Device + lvl::Lvl + ptr::Ptr + task::Task + used::Used + alloc::Alloc +end +const Shard = ShardLevel + +function ShardLevel(device::Device, lvl::Lvl) where {Device,Lvl} + Tp = postype(lvl) + ptr = transfer(shared_memory(device), Tp[]) + task = transfer(shared_memory(device), Tp[]) + used = transfer(shared_memory(device), zeros(Tp, get_num_tasks(device))) + alloc = transfer(shared_memory(device), zeros(Tp, get_num_tasks(device))) + lvl = transfer(MultiChannelMemory(device, get_num_tasks(device)), lvl) + ShardLevel{Device}(device, transfer(MultiChannelMemory(device, get_num_tasks(device)), lvl), ptr, task, used, alloc) +end + +function ShardLevel{Device}( + device, lvl::Lvl, ptr::Ptr, task::Task, used::Used, alloc::Alloc +) where {Device,Lvl,Ptr,Task,Used,Alloc} + ShardLevel{Device,Lvl,Ptr,Task,Used,Alloc}(device, lvl, ptr, task, used, alloc) +end + +function Base.summary(::Shard{Device,Lvl,Ptr,Task,Used,Alloc}) where {Device,Lvl,Ptr,Task,Used,Alloc} + "Shard($(Lvl))" +end + +function similar_level( + lvl::Shard{Device,Lvl,Ptr,Task,Used,Alloc}, fill_value, eltype::Type, dims... +) where {Device,Lvl,Ptr,Task,Used,Alloc} + lvl_2 = similar(lvl.lvl, fill_value, eltype, dims...) 
+ ShardLevel(lvl.device, transfer(MultiChannelMemory(lvl.device, get_num_tasks(lvl.device)), lvl_2)) +end + +function postype(::Type{<:Shard{Device,Lvl,Ptr,Task,Used,Alloc}}) where {Device,Lvl,Ptr,Task,Used,Alloc} + postype(Lvl) +end + +function transfer(device, lvl::ShardLevel) + #lvl_2 = transfer(MultiChannelMemory(lvl.device, get_num_tasks(lvl.device)), lvl.lvl) + lvl_2 = transfer(device, lvl.lvl) #TODO unclear + ptr_2 = transfer(device, lvl.ptr) + task_2 = transfer(device, lvl.task) + qos_used_2 = transfer(device, lvl.used) + qos_alloc_2 = transfer(device, lvl.alloc) + return ShardLevel(lvl_2, ptr_2, task_2, qos_used_2, qos_alloc_2) +end + +function pattern!(lvl::ShardLevel) + ShardLevel(pattern!(lvl.lvl), lvl.ptr, lvl.task, lvl.used, lvl.alloc) +end + +function set_fill_value!(lvl::ShardLevel, init) + ShardLevel( + set_fill_value!(lvl.lvl, init), + lvl.ptr, + lvl.task, + lvl.used, + lvl.alloc, + ) +end + +function Base.resize!(lvl::ShardLevel, dims...) + ShardLevel( + resize!(lvl.lvl, dims...), + lvl.ptr, + lvl.task, + lvl.used, + lvl.alloc, + ) +end + +function Base.show( + io::IO, lvl::ShardLevel{Device,Lvl,Ptr,Task,Used,Alloc} +) where {Device,Lvl,Ptr,Task,Used,Alloc} + print(io, "Shard(") + if get(io, :compact, false) + print(io, "…") + else + show(io, lvl.lvl) + print(io, ", ") + show(io, lvl.ptr) + print(io, ", ") + show(io, lvl.task) + print(io, ", ") + show(io, lvl.used) + print(io, ", ") + show(io, lvl.alloc) + end + print(io, ")") +end + +function labelled_show(io::IO, fbr::SubFiber{<:ShardLevel}) + (lvl, pos) = (fbr.lvl, fbr.pos) + print(io, "shard($(lvl.task[pos])) -> ") +end + +function labelled_children(fbr::SubFiber{<:ShardLevel}) + lvl = fbr.lvl + pos = fbr.pos + pos > length(lvl.ptr) && return [] + lvl_2 = transfer(MemoryChannel(lvl.task[pos], MultiChannelMemory(lvl.device, get_num_tasks(lvl.device)), SerialTask()), lvl.lvl) + [LabelledTree(SubFiber(lvl_2, lvl.ptr[pos]))] +end + +@inline level_ndims( + 
::Type{<:ShardLevel{Device,Lvl,Ptr,Task,Used,Alloc}} +) where {Device,Lvl,Ptr,Task,Used,Alloc} = level_ndims(Lvl) +@inline level_size( + lvl::ShardLevel{Device,Lvl,Ptr,Task,Used,Alloc} +) where {Device,Lvl,Ptr,Task,Used,Alloc} = level_size(lvl.lvl) +@inline level_axes( + lvl::ShardLevel{Device,Lvl,Ptr,Task,Used,Alloc} +) where {Device,Lvl,Ptr,Task,Used,Alloc} = level_axes(lvl.lvl) +@inline level_eltype( + ::Type{ShardLevel{Device,Lvl,Ptr,Task,Used,Alloc}} +) where {Device,Lvl,Ptr,Task,Used,Alloc} = level_eltype(Lvl) +@inline level_fill_value( + ::Type{<:ShardLevel{Device,Lvl,Ptr,Task,Used,Alloc}} +) where {Device,Lvl,Ptr,Task,Used,Alloc} = level_fill_value(Lvl) + +function (fbr::SubFiber{<:ShardLevel})(idxs...) + lvl = fbr.lvl + pos = fbr.pos + pos > length(lvl.ptr) && return [] + lvl_2 = transfer(MemoryChannel(lvl.task[pos], MultiChannelMemory(lvl.device, get_num_tasks(lvl.device)), SerialTask()), lvl.lvl) + SubFiber(lvl_2, lvl.ptr[pos])(idxs...) +end + +function countstored_level(lvl::ShardLevel, pos) + sum(1:pos) do qos + lvl_2 = transfer(MemoryChannel(lvl.task[qos], MultiChannelMemory(lvl.device, get_num_tasks(lvl.device)), SerialTask()), lvl.lvl) + countstored_level(lvl_2, lvl.used[qos]) + end +end + +mutable struct VirtualShardLevel <: AbstractVirtualLevel + tag + device + lvl + ptr + task + used + alloc + qos_used + qos_alloc + Tv + Device + Lvl + Ptr + Task + Used + Alloc +end + +postype(lvl::VirtualShardLevel) = postype(lvl.lvl) + +function is_level_injective(ctx, lvl::VirtualShardLevel) + [is_level_injective(ctx, lvl.lvl)..., true] +end +function is_level_atomic(ctx, lvl::VirtualShardLevel) + (below, atomic) = is_level_atomic(ctx, lvl.lvl) + return ([below; [atomic]], atomic) +end +function is_level_concurrent(ctx, lvl::VirtualShardLevel) + (data, _) = is_level_concurrent(ctx, lvl.lvl) + return (data, true) +end + +function lower(ctx::AbstractCompiler, lvl::VirtualShardLevel, ::DefaultStyle) + quote + $ShardLevel( + $(ctx(lvl.device)), + $(ctx(lvl.lvl)), 
+ $(ctx(lvl.ptr)), + $(ctx(lvl.task)), + $(ctx(lvl.used)), + $(ctx(lvl.alloc)) + ) + end +end + +function virtualize( + ctx, ex, ::Type{ShardLevel{Device,Lvl,Ptr,Task,Used,Alloc}}, tag=:lvl +) where {Device,Lvl,Ptr,Task,Used,Alloc} + tag = freshen(ctx, tag) + ptr = freshen(ctx, tag, :_ptr) + task = freshen(ctx, tag, :_task) + used = freshen(ctx, tag, :_qos_used) + alloc = freshen(ctx, tag, :_qos_alloc) + + push_preamble!( + ctx, + quote + $tag = $ex + $ptr = $tag.ptr + $task = $tag.task + $used = $tag.used + $alloc = $tag.alloc + end, + ) + device_2 = virtualize(ctx, :($tag.device), Device, tag) + lvl_2 = virtualize(ctx, :($tag.lvl), Lvl, tag) + VirtualShardLevel(tag, device_2, lvl_2, ptr, task, used, alloc, nothing, nothing, typeof(level_fill_value(Lvl)), Device, Lvl, Ptr, Task, Used, Alloc) +end + +function distribute_level(ctx, lvl::VirtualShardLevel, arch, diff, style) + diff[lvl.tag] = VirtualShardLevel( + lvl.tag, + lvl.device, + distribute_level(ctx, lvl.lvl, arch, diff, style), + distribute_buffer(ctx, lvl.ptr, arch, style), + distribute_buffer(ctx, lvl.task, arch, style), + distribute_buffer(ctx, lvl.used, arch, style), + distribute_buffer(ctx, lvl.alloc, arch, style), + lvl.qos_used, + lvl.qos_alloc, + lvl.Tv, + lvl.Device, + lvl.Lvl, + lvl.Ptr, + lvl.Task, + lvl.Used, + lvl.Alloc, + ) +end + +function distribute_level(ctx, lvl::VirtualShardLevel, arch, diff, style::Union{DeviceShared}) + Tp = postype(lvl) + tag = lvl.tag + if true #get_device(arch) == lvl.device + qos_used = freshen(ctx, tag, :_qos_used) + qos_alloc = freshen(ctx, tag, :_qos_alloc) + tid = ctx(get_task_num(arch)) + push_preamble!(ctx, quote + $qos_used = $(lvl.used)[$tid] + $qos_alloc = $(lvl.alloc)[$tid] + end) + dev = get_device(arch) + multi_channel_dev = VirtualMultiChannelMemory(dev, get_num_tasks(dev)) + channel_task = VirtualMemoryChannel(get_task_num(arch), multi_channel_dev, arch) + lvl_2 = distribute_level(ctx, lvl.lvl, channel_task, diff, style) + lvl_2 = thaw_level!(ctx, 
lvl_2, value(qos_alloc, Tp)) + push_epilogue!(ctx, contain(ctx) do ctx_2 + quote + $(lvl.used)[$tid] = $qos_used + $(lvl.alloc)[$tid] = $qos_alloc + end + freeze_level!(ctx_2, lvl_2, qos_alloc) + end) + diff[lvl.tag] = VirtualShardLevel( + lvl.tag, + lvl.device, + lvl_2, + distribute_buffer(ctx, lvl.ptr, arch, style), + distribute_buffer(ctx, lvl.task, arch, style), + distribute_buffer(ctx, lvl.used, arch, style), + distribute_buffer(ctx, lvl.alloc, arch, style), + qos_used, + qos_alloc, + lvl.Tv, + lvl.Device, + lvl.Lvl, + lvl.Ptr, + lvl.Task, + lvl.Used, + lvl.Alloc, + ) + else + diff[lvl.tag] = VirtualShardLevel( + lvl.tag, + lvl.device, + distribute_level(ctx, lvl.lvl, arch, diff, style), + distribute_buffer(ctx, lvl.ptr, arch, style), + distribute_buffer(ctx, lvl.task, arch, style), + distribute_buffer(ctx, lvl.used, arch, style), + distribute_buffer(ctx, lvl.alloc, arch, style), + lvl.qos_used, + lvl.qos_alloc, + lvl.Tv, + lvl.Device, + lvl.Lvl, + lvl.Ptr, + lvl.Task, + lvl.Used, + lvl.Alloc, + ) + end +end + +function redistribute(ctx::AbstractCompiler, lvl::VirtualShardLevel, diff) + get( + diff, + lvl.tag, + VirtualShardLevel( + lvl.tag, + lvl.device, + redistribute(ctx, lvl.lvl, diff), + lvl.ptr, + lvl.task, + lvl.used, + lvl.alloc, + lvl.qos_used, + lvl.qos_alloc, + lvl.Tv, + lvl.Device, + lvl.Lvl, + lvl.Ptr, + lvl.Task, + lvl.Used, + lvl.Alloc, + ), + ) +end + +Base.summary(lvl::VirtualShardLevel) = "Shard($(lvl.Lvl))" + +function virtual_level_resize!(ctx, lvl::VirtualShardLevel, dims...) 
+ (lvl.lvl = virtual_level_resize!(ctx, lvl.lvl, dims...); lvl) +end +virtual_level_size(ctx, lvl::VirtualShardLevel) = virtual_level_size(ctx, lvl.lvl) +virtual_level_eltype(lvl::VirtualShardLevel) = virtual_level_eltype(lvl.lvl) +virtual_level_fill_value(lvl::VirtualShardLevel) = virtual_level_fill_value(lvl.lvl) + +function declare_level!(ctx, lvl::VirtualShardLevel, pos, init) + @assert !is_on_device(ctx, lvl.device) + push_preamble!(ctx, contain(ctx) do ctx_2 + diff = Dict() + lvl_2 = distribute_level(ctx_2, lvl.lvl, lvl.device, diff, HostShared()) + used = distribute_buffer(ctx_2, lvl.used, lvl.device, HostShared()) + alloc = distribute_buffer(ctx_2, lvl.alloc, lvl.device, HostShared()) + virtual_parallel_region(ctx_2, lvl.device) do ctx_3 + task = get_task(ctx_3) + multi_channel_dev = VirtualMultiChannelMemory(lvl.device, get_num_tasks(lvl.device)) + channel_task = VirtualMemoryChannel(get_task_num(task), multi_channel_dev, task) + lvl_3 = distribute_level(ctx_3, lvl.lvl, channel_task, diff, DeviceShared()) + used = distribute_buffer(ctx_3, lvl.used, task, DeviceShared()) + alloc = distribute_buffer(ctx_3, lvl.alloc, task, DeviceShared()) + lvl_4 = declare_level!(ctx_3, lvl_3, literal(1), init) + freeze_level!(ctx_3, lvl_4, literal(1)) + tid = ctx_3(get_task_num(ctx_3)) + quote + $(ctx_3(used))[$tid] = 0 + $(ctx_3(alloc))[$tid] = max($(ctx_3(alloc))[$tid], 1) + end + end + end + ) + lvl +end + +function assemble_level!(ctx, lvl::VirtualShardLevel, pos_start, pos_stop) + @assert !is_on_device(ctx, lvl.device) + pos_start = cache!(ctx, :pos_start, simplify(ctx, pos_start)) + pos_stop = cache!(ctx, :pos_stop, simplify(ctx, pos_stop)) + pos = freshen(ctx, :pos) + sym = freshen(ctx, :pointer_to_lvl) + push_preamble!( + ctx, + quote + Finch.resize_if_smaller!($(lvl.task), $(ctx(pos_stop))) + Finch.resize_if_smaller!($(lvl.ptr), $(ctx(pos_stop))) + Finch.fill_range!($(lvl.ptr), 0, $(ctx(pos_start)), $(ctx(pos_stop))) + end, + ) + lvl +end + 
+supports_reassembly(::VirtualShardLevel) = false + +""" +these two are no-ops, we instead do these on distribute +""" +function freeze_level!(ctx, lvl::VirtualShardLevel, pos) + @assert !is_on_device(ctx, lvl.device) + return lvl +end + +function thaw_level!(ctx::AbstractCompiler, lvl::VirtualShardLevel, pos) + @assert !is_on_device(ctx, lvl.device) + return lvl +end + +function instantiate(ctx, fbr::VirtualSubFiber{VirtualShardLevel}, mode) + (lvl, pos) = (fbr.lvl, fbr.pos) + Tp = postype(lvl) + if mode.kind === reader + tag = lvl.tag + isnulltest = freshen(ctx, tag, :_nulltest) + Vf = level_fill_value(lvl.Lvl) + sym = freshen(ctx, :pointer_to_lvl) + val = freshen(ctx, lvl.tag, :_val) + t = freshen(ctx, tag, :_t) + qos = freshen(ctx, tag, :_q) + push_preamble!(ctx, quote + $t = $(lvl.task)[$(ctx(pos))] + $qos = $(lvl.ptr)[$(ctx(pos))] + end) + task = get_task(ctx) + multi_channel_dev = VirtualMultiChannelMemory(lvl.device, get_num_tasks(lvl.device)) + channel_task = VirtualMemoryChannel(value(t, Tp), multi_channel_dev, task) + lvl_2 = distribute_level(ctx, lvl.lvl, channel_task, Dict(), DeviceGlobal()) + instantiate(ctx, VirtualSubFiber(lvl_2, value(qos, Tp)), mode) + else + @assert is_on_device(ctx, lvl.device) + instantiate(ctx, VirtualHollowSubFiber(lvl, pos, freshen(ctx, :dirty)), mode) + end +end + +""" +assemble: + mapping is pos -> task, ptr. task says which task has it, ptr says which position in that task has it. + +read: + read from pos to task, ptr. simple. + +write: + allocate something for this task on that position, assemble on the task itself on demand. Complain if the task is wrong. + +The outer level needs to be concurrent, like denselevel. 
+""" +function instantiate(ctx, fbr::VirtualHollowSubFiber{VirtualShardLevel}, mode) + @assert mode.kind === updater + (lvl, pos) = (fbr.lvl, fbr.pos) + tag = lvl.tag + sym = freshen(ctx, :pointer_to_lvl) + Tp = postype(lvl) + + tid = freshen(ctx, tag, :_tid) + qos = freshen(ctx, :qos) + + @assert is_on_device(ctx, lvl.device) + + return Thunk(; + preamble = quote + $qos = $(lvl.ptr)[$(ctx(pos))] + $tid = $(ctx(get_task_num(ctx))) + if $qos == 0 + #this task will always own this position forever, even if we don't write to it. + $qos = $(lvl.qos_used) += 1 + $(lvl.task)[$(ctx(pos))] = $tid + $(lvl.ptr)[$(ctx(pos))] = $(lvl.qos_used) + if $(lvl.qos_used) > $(lvl.qos_alloc) + $(lvl.qos_alloc) = max($(lvl.qos_alloc) << 1, 1) + $(contain(ctx_2 -> assemble_level!(ctx_2, lvl.lvl, value(lvl.qos_used, Tp), value(lvl.qos_alloc, Tp)), ctx)) + end + else + if $(get_mode_flag(ctx) === :safe) + @assert $(lvl.task)[$(ctx(pos))] == $tid "Task mismatch in ShardLevel" + end + end + end, + body = (ctx) -> VirtualHollowSubFiber(lvl.lvl, value(qos), fbr.dirty), + ) +end diff --git a/src/tensors/levels/sparse_band_levels.jl b/src/tensors/levels/sparse_band_levels.jl index a63d9b0ed..c47ddca88 100644 --- a/src/tensors/levels/sparse_band_levels.jl +++ b/src/tensors/levels/sparse_band_levels.jl @@ -153,8 +153,8 @@ mutable struct VirtualSparseBandLevel <: AbstractVirtualLevel lvl Ti shape - qos_fill - qos_stop + qos_used + qos_alloc ros_fill ros_stop dirty @@ -181,8 +181,8 @@ function virtualize( ctx, ex, ::Type{SparseBandLevel{Ti,Idx,Ofs,Lvl}}, tag=:lvl ) where {Ti,Idx,Ofs,Lvl} tag = freshen(ctx, tag) - qos_fill = freshen(ctx, tag, :_qos_fill) - qos_stop = freshen(ctx, tag, :_qos_stop) + qos_used = freshen(ctx, tag, :_qos_used) + qos_alloc = freshen(ctx, tag, :_qos_alloc) ros_fill = freshen(ctx, tag, :_ros_fill) ros_stop = freshen(ctx, tag, :_ros_stop) dirty = freshen(ctx, tag, :_dirty) @@ -206,8 +206,8 @@ function virtualize( lvl_2, Ti, shape, - qos_fill, - qos_stop, + qos_used, + 
qos_alloc, ros_fill, ros_stop, dirty, @@ -235,8 +235,8 @@ function distribute_level( distribute_level(ctx, lvl.lvl, arch, diff, style), lvl.Ti, lvl.shape, - lvl.qos_fill, - lvl.qos_stop, + lvl.qos_used, + lvl.qos_alloc, lvl.ros_fill, lvl.ros_stop, lvl.dirty, @@ -255,8 +255,8 @@ function redistribute(ctx::AbstractCompiler, lvl::VirtualSparseBandLevel, diff) redistribute(ctx, lvl.lvl, diff), lvl.Ti, lvl.shape, - lvl.qos_fill, - lvl.qos_stop, + lvl.qos_used, + lvl.qos_alloc, lvl.ros_fill, lvl.ros_stop, lvl.dirty, @@ -289,8 +289,8 @@ function declare_level!(ctx::AbstractCompiler, lvl::VirtualSparseBandLevel, pos, push_preamble!( ctx, quote - $(lvl.qos_fill) = $(Tp(0)) - $(lvl.qos_stop) = $(Tp(0)) + $(lvl.qos_used) = $(Tp(0)) + $(lvl.qos_alloc) = $(Tp(0)) Finch.resize_if_smaller!($(lvl.ofs), 1) $(lvl.ofs)[1] = 1 end, @@ -322,7 +322,7 @@ function freeze_level!(ctx::AbstractCompiler, lvl::VirtualSparseBandLevel, pos_s p = freshen(ctx, :p) Tp = postype(lvl) pos_stop = ctx(cache!(ctx, :pos_stop, simplify(ctx, pos_stop))) - qos_stop = freshen(ctx, :qos_stop) + qos_alloc = freshen(ctx, :qos_alloc) push_preamble!( ctx, quote @@ -331,10 +331,10 @@ function freeze_level!(ctx::AbstractCompiler, lvl::VirtualSparseBandLevel, pos_s for $p in 2:($pos_stop + 1) $(lvl.ofs)[$p] += $(lvl.ofs)[$p - 1] end - $qos_stop = $(lvl.ofs)[$pos_stop + 1] - $(Tp(1)) + $qos_alloc = $(lvl.ofs)[$pos_stop + 1] - $(Tp(1)) end, ) - lvl.lvl = freeze_level!(ctx, lvl.lvl, value(qos_stop)) + lvl.lvl = freeze_level!(ctx, lvl.lvl, value(qos_alloc)) return lvl end @@ -418,8 +418,8 @@ function unfurl( qos = freshen(ctx, tag, :_qos) qos_set = freshen(ctx, tag, :_qos_set) ros = freshen(ctx, tag, :_ros) - qos_fill = lvl.qos_fill - qos_stop = lvl.qos_stop + qos_used = lvl.qos_used + qos_alloc = lvl.qos_alloc ros_fill = lvl.ros_fill ros_stop = lvl.ros_stop dirty = freshen(ctx, tag, :dirty) @@ -429,8 +429,8 @@ function unfurl( arr=fbr, body=Thunk(; preamble = quote - $qos = $qos_fill + 1 - $qos_set = $qos_fill + $qos = 
$qos_used + 1 + $qos_set = $qos_used $my_i_prev = $(Ti(-1)) $my_i_set = $(Ti(-1)) $(if issafe(get_mode_flag(ctx)) @@ -450,14 +450,14 @@ function unfurl( end end end) - $qos = $(ctx(idx)) - $my_i_prev + $qos_fill + 1 + $qos = $(ctx(idx)) - $my_i_prev + $qos_used + 1 end - if $qos > $qos_stop - $qos_2 = $qos_stop + 1 - while $qos > $qos_stop - $qos_stop = max($qos_stop << 1, 1) + if $qos > $qos_alloc + $qos_2 = $qos_alloc + 1 + while $qos > $qos_alloc + $qos_alloc = max($qos_alloc << 1, 1) end - $(contain(ctx_2 -> assemble_level!(ctx_2, lvl.lvl, value(qos_2, Tp), value(qos_stop, Tp)), ctx)) + $(contain(ctx_2 -> assemble_level!(ctx_2, lvl.lvl, value(qos_2, Tp), value(qos_alloc, Tp)), ctx)) end $dirty = false end, @@ -484,7 +484,7 @@ function unfurl( $(lvl.prev_pos) = $(ctx(pos)) end end) - $qos_fill = $qos + $qos_used = $qos end end, ), diff --git a/src/tensors/levels/sparse_bytemap_levels.jl b/src/tensors/levels/sparse_bytemap_levels.jl index fdb82677d..7fd3b0b2d 100644 --- a/src/tensors/levels/sparse_bytemap_levels.jl +++ b/src/tensors/levels/sparse_bytemap_levels.jl @@ -188,8 +188,8 @@ mutable struct VirtualSparseByteMapLevel <: AbstractVirtualLevel tbl srt shape - qos_fill - qos_stop + qos_used + qos_alloc end function is_level_injective(ctx, lvl::VirtualSparseByteMapLevel) @@ -209,8 +209,8 @@ function virtualize( ) where {Ti,Ptr,Tbl,Srt,Lvl} tag = freshen(ctx, tag) shape = value(:($tag.shape), Int) - qos_fill = freshen(ctx, tag, :_qos_fill) - qos_stop = freshen(ctx, tag, :_qos_stop) + qos_used = freshen(ctx, tag, :_qos_used) + qos_alloc = freshen(ctx, tag, :_qos_alloc) ptr = freshen(ctx, tag, :_ptr) tbl = freshen(ctx, tag, :_tbl) srt = freshen(ctx, tag, :_srt) @@ -222,13 +222,13 @@ function virtualize( $ptr = $tag.ptr $tbl = $tag.tbl $srt = $tag.srt - $qos_stop = $qos_fill = length($tag.srt) + $qos_alloc = $qos_used = length($tag.srt) $stop = $tag.shape end, ) shape = value(stop, Int) lvl_2 = virtualize(ctx, :($tag.lvl), Lvl, tag) - VirtualSparseByteMapLevel(tag, 
lvl_2, Ti, ptr, tbl, srt, shape, qos_fill, qos_stop) + VirtualSparseByteMapLevel(tag, lvl_2, Ti, ptr, tbl, srt, shape, qos_used, qos_alloc) end function lower(ctx::AbstractCompiler, lvl::VirtualSparseByteMapLevel, ::DefaultStyle) quote @@ -253,8 +253,8 @@ function distribute_level( distribute_buffer(ctx, lvl.tbl, arch, style), distribute_buffer(ctx, lvl.srt, arch, style), lvl.shape, - lvl.qos_fill, - lvl.qos_stop, + lvl.qos_used, + lvl.qos_alloc, ) end @@ -270,8 +270,8 @@ function redistribute(ctx::AbstractCompiler, lvl::VirtualSparseByteMapLevel, dif lvl.tbl, lvl.srt, lvl.shape, - lvl.qos_fill, - lvl.qos_stop, + lvl.qos_used, + lvl.qos_alloc, ), ) end @@ -304,7 +304,7 @@ function declare_level!(ctx::AbstractCompiler, lvl::VirtualSparseByteMapLevel, p push_preamble!( ctx, quote - for $r in 1:($(lvl.qos_fill)) + for $r in 1:($(lvl.qos_used)) $p = first($(lvl.srt)[$r]) $(lvl.ptr)[$p] = $(Tp(0)) $(lvl.ptr)[$p + 1] = $(Tp(0)) @@ -319,9 +319,9 @@ function declare_level!(ctx::AbstractCompiler, lvl::VirtualSparseByteMapLevel, p )) end end - $(lvl.qos_fill) = 0 + $(lvl.qos_used) = 0 if $(!supports_reassembly(lvl.lvl)) - $(lvl.qos_stop) = $(Tp(0)) + $(lvl.qos_alloc) = $(Tp(0)) end $(lvl.ptr)[1] = 1 end, @@ -385,10 +385,10 @@ function freeze_level!(ctx::AbstractCompiler, lvl::VirtualSparseByteMapLevel, po quote resize!($(lvl.ptr), $(ctx(pos_stop)) + 1) resize!($(lvl.tbl), $(ctx(pos_stop)) * $(ctx(lvl.shape))) - resize!($(lvl.srt), $(lvl.qos_fill)) + resize!($(lvl.srt), $(lvl.qos_used)) sort!($(lvl.srt)) $p_prev = $(Tp(0)) - for $r in 1:($(lvl.qos_fill)) + for $r in 1:($(lvl.qos_used)) $p = first($(lvl.srt)[$r]) if $p != $p_prev $(lvl.ptr)[$p_prev + 1] = $r @@ -396,8 +396,8 @@ function freeze_level!(ctx::AbstractCompiler, lvl::VirtualSparseByteMapLevel, po end $p_prev = $p end - $(lvl.ptr)[$p_prev + 1] = $(lvl.qos_fill) + 1 - $(lvl.qos_stop) = $(lvl.qos_fill) + $(lvl.ptr)[$p_prev + 1] = $(lvl.qos_used) + 1 + $(lvl.qos_alloc) = $(lvl.qos_used) end, ) lvl.lvl = 
freeze_level!(ctx, lvl.lvl, call(*, pos_stop, lvl.shape)) @@ -586,12 +586,12 @@ function unfurl( $(fbr.dirty) = true if !$(lvl.tbl)[$my_q] $(lvl.tbl)[$my_q] = true - $(lvl.qos_fill) += 1 - if $(lvl.qos_fill) > $(lvl.qos_stop) - $(lvl.qos_stop) = max($(lvl.qos_stop) << 1, 1) - Finch.resize_if_smaller!($(lvl.srt), $(lvl.qos_stop)) + $(lvl.qos_used) += 1 + if $(lvl.qos_used) > $(lvl.qos_alloc) + $(lvl.qos_alloc) = max($(lvl.qos_alloc) << 1, 1) + Finch.resize_if_smaller!($(lvl.srt), $(lvl.qos_alloc)) end - $(lvl.srt)[$(lvl.qos_fill)] = ($(ctx(pos)), $(ctx(idx))) + $(lvl.srt)[$(lvl.qos_used)] = ($(ctx(pos)), $(ctx(idx))) end end end, diff --git a/src/tensors/levels/sparse_coo_levels.jl b/src/tensors/levels/sparse_coo_levels.jl index 4a014f9d1..8ad73c5c6 100644 --- a/src/tensors/levels/sparse_coo_levels.jl +++ b/src/tensors/levels/sparse_coo_levels.jl @@ -214,8 +214,8 @@ mutable struct VirtualSparseCOOLevel <: AbstractVirtualLevel tbl Lvl shape - qos_fill - qos_stop + qos_used + qos_alloc prev_pos end @@ -235,8 +235,8 @@ function virtualize( ctx, ex, ::Type{SparseCOOLevel{N,TI,Ptr,Tbl,Lvl}}, tag=:lvl ) where {N,TI,Ptr,Tbl,Lvl} tag = freshen(ctx, tag) - qos_fill = freshen(ctx, tag, :_qos_fill) - qos_stop = freshen(ctx, tag, :_qos_stop) + qos_used = freshen(ctx, tag, :_qos_used) + qos_alloc = freshen(ctx, tag, :_qos_alloc) ptr = freshen(ctx, tag, :_ptr) tbl = map(n -> freshen(ctx, tag, :_tbl, n), 1:N) stop = map(n -> freshen(ctx, tag, :_stop, n), 1:N) @@ -261,7 +261,7 @@ function virtualize( prev_pos = freshen(ctx, tag, :_prev_pos) prev_coord = map(n -> freshen(ctx, tag, :_prev_coord_, n), 1:N) VirtualSparseCOOLevel( - tag, lvl_2, N, TI, ptr, tbl, Lvl, shape, qos_fill, qos_stop, prev_pos + tag, lvl_2, N, TI, ptr, tbl, Lvl, shape, qos_used, qos_alloc, prev_pos ) end function lower(ctx::AbstractCompiler, lvl::VirtualSparseCOOLevel, ::DefaultStyle) @@ -287,8 +287,8 @@ function distribute_level( map(idx -> distribute_buffer(ctx, idx, arch, style), lvl.tbl), lvl.Lvl, lvl.shape, 
- lvl.qos_fill, - lvl.qos_stop, + lvl.qos_used, + lvl.qos_alloc, lvl.prev_pos, ) end @@ -306,8 +306,8 @@ function redistribute(ctx::AbstractCompiler, lvl::VirtualSparseCOOLevel, diff) lvl.tbl, lvl.Lvl, lvl.shape, - lvl.qos_fill, - lvl.qos_stop, + lvl.qos_used, + lvl.qos_alloc, lvl.prev_pos, ), ) @@ -341,8 +341,8 @@ function declare_level!(ctx::AbstractCompiler, lvl::VirtualSparseCOOLevel, pos, push_preamble!( ctx, quote - $(lvl.qos_fill) = $(Tp(0)) - $(lvl.qos_stop) = $(Tp(0)) + $(lvl.qos_used) = $(Tp(0)) + $(lvl.qos_alloc) = $(Tp(0)) end, ) if issafe(get_mode_flag(ctx)) @@ -369,7 +369,7 @@ end function freeze_level!(ctx::AbstractCompiler, lvl::VirtualSparseCOOLevel, pos_stop) p = freshen(ctx, :p) pos_stop = ctx(cache!(ctx, :pos_stop, simplify(ctx, pos_stop))) - qos_stop = freshen(ctx, :qos_stop) + qos_alloc = freshen(ctx, :qos_alloc) push_preamble!( ctx, quote @@ -377,13 +377,13 @@ function freeze_level!(ctx::AbstractCompiler, lvl::VirtualSparseCOOLevel, pos_st for $p in 2:($pos_stop + 1) $(lvl.ptr)[$p] += $(lvl.ptr)[$p - 1] end - $qos_stop = $(lvl.ptr)[$pos_stop + 1] - 1 + $qos_alloc = $(lvl.ptr)[$pos_stop + 1] - 1 $(Expr(:block, map(1:(lvl.N)) do n - :(resize!($(lvl.tbl[n]), $qos_stop)) + :(resize!($(lvl.tbl[n]), $qos_alloc)) end...)) end, ) - lvl.lvl = freeze_level!(ctx, lvl.lvl, value(qos_stop)) + lvl.lvl = freeze_level!(ctx, lvl.lvl, value(qos_alloc)) return lvl end struct SparseCOOWalkTraversal @@ -533,14 +533,14 @@ function unfurl( tag = lvl.tag TI = lvl.TI Tp = postype(lvl) - qos_fill = lvl.qos_fill - qos_stop = lvl.qos_stop + qos_used = lvl.qos_used + qos_alloc = lvl.qos_alloc qos = freshen(ctx, tag, :_q) prev_coord = freshen(ctx, tag, :_prev_coord) Thunk(; preamble = quote - $qos = $qos_fill + 1 + $qos = $qos_used + 1 $(if issafe(get_mode_flag(ctx)) quote $(lvl.prev_pos) < $(ctx(pos)) || throw(FinchProtocolError("SparseCOOLevels cannot be updated multiple times")) @@ -550,15 +550,15 @@ function unfurl( end, body = (ctx) -> unfurl(ctx, 
SparseCOOExtrudeTraversal(lvl, qos, fbr.dirty, [], prev_coord), ext, mode, proto), epilogue = quote - $(lvl.ptr)[$(ctx(pos)) + 1] = $qos - $qos_fill - 1 + $(lvl.ptr)[$(ctx(pos)) + 1] = $qos - $qos_used - 1 $(if issafe(get_mode_flag(ctx)) quote - if $qos - $qos_fill - 1 > 0 + if $qos - $qos_used - 1 > 0 $(lvl.prev_pos) = $(ctx(pos)) end end end) - $qos_fill = $qos - 1 + $qos_used = $qos - 1 end, ) end @@ -573,8 +573,8 @@ function unfurl( (lvl, qos, fbr_dirty, coords) = (trv.lvl, trv.qos, trv.fbr_dirty, trv.coords) TI = lvl.TI Tp = postype(lvl) - qos_fill = lvl.qos_fill - qos_stop = lvl.qos_stop + qos_used = lvl.qos_used + qos_alloc = lvl.qos_alloc if length(coords) + 1 < lvl.N Lookup(; body=(ctx, i) -> instantiate( @@ -590,12 +590,12 @@ function unfurl( Lookup(; body=(ctx, idx) -> Thunk(; preamble = quote - if $qos > $qos_stop - $qos_stop = max($qos_stop << 1, 1) + if $qos > $qos_alloc + $qos_alloc = max($qos_alloc << 1, 1) $(Expr(:block, map(1:(lvl.N)) do n - :(Finch.resize_if_smaller!($(lvl.tbl[n]), $qos_stop)) + :(Finch.resize_if_smaller!($(lvl.tbl[n]), $qos_alloc)) end...)) - $(contain(ctx_2 -> assemble_level!(ctx_2, lvl.lvl, value(qos, Tp), value(qos_stop, Tp)), ctx)) + $(contain(ctx_2 -> assemble_level!(ctx_2, lvl.lvl, value(qos, Tp), value(qos_alloc, Tp)), ctx)) end $dirty = false end, diff --git a/src/tensors/levels/sparse_dict_levels.jl b/src/tensors/levels/sparse_dict_levels.jl index 38047899f..f19df613b 100644 --- a/src/tensors/levels/sparse_dict_levels.jl +++ b/src/tensors/levels/sparse_dict_levels.jl @@ -220,7 +220,7 @@ mutable struct VirtualSparseDictLevel <: AbstractVirtualLevel tbl pool shape - qos_stop + qos_alloc end function is_level_injective(ctx, lvl::VirtualSparseDictLevel) @@ -244,7 +244,7 @@ function virtualize( val = freshen(ctx, tag, :_val) tbl = freshen(ctx, tag, :_tbl) pool = freshen(ctx, tag, :_pool) - qos_stop = freshen(ctx, tag, :_qos_stop) + qos_alloc = freshen(ctx, tag, :_qos_alloc) stop = freshen(ctx, tag, :_stop) push_preamble!( 
ctx, @@ -255,13 +255,13 @@ function virtualize( $val = $tag.val $tbl = $tag.tbl $pool = $tag.pool - $qos_stop = length($tbl) + $qos_alloc = length($tbl) $stop = $tag.shape end, ) shape = value(stop, Int) lvl_2 = virtualize(ctx, :($tag.lvl), Lvl, tag) - VirtualSparseDictLevel(tag, lvl_2, Ti, ptr, idx, val, tbl, pool, shape, qos_stop) + VirtualSparseDictLevel(tag, lvl_2, Ti, ptr, idx, val, tbl, pool, shape, qos_alloc) end function lower(ctx::AbstractCompiler, lvl::VirtualSparseDictLevel, ::DefaultStyle) quote @@ -290,7 +290,7 @@ function distribute_level( distribute_buffer(ctx, lvl.tbl, arch, style), lvl.pool, lvl.shape, - lvl.qos_stop, + lvl.qos_alloc, ) end @@ -308,7 +308,7 @@ function redistribute(ctx::AbstractCompiler, lvl::VirtualSparseDictLevel, diff) lvl.tbl, lvl.pool, lvl.shape, - lvl.qos_stop, + lvl.qos_alloc, ), ) end @@ -342,7 +342,7 @@ function declare_level!(ctx::AbstractCompiler, lvl::VirtualSparseDictLevel, pos, empty!($(lvl.tbl)) empty!($(lvl.pool)) $qos = $(Tp(0)) - $(lvl.qos_stop) = 0 + $(lvl.qos_alloc) = 0 end, ) lvl.lvl = declare_level!(ctx, lvl.lvl, value(qos, Tp), init) @@ -359,7 +359,7 @@ function freeze_level!(ctx::AbstractCompiler, lvl::VirtualSparseDictLevel, pos_s Tp = postype(lvl) Ti = lvl.Ti pos_stop = cache!(ctx, :pos_stop, simplify(ctx, pos_stop)) - qos_stop = freshen(ctx, :qos_stop) + qos_alloc = freshen(ctx, :qos_alloc) p = freshen(ctx, :p) q = freshen(ctx, :q) r = freshen(ctx, :r) @@ -403,10 +403,10 @@ function freeze_level!(ctx::AbstractCompiler, lvl::VirtualSparseDictLevel, pos_s $(lvl.val)[$r] = $val_tmp[$q] $ptr_2[$p] += 1 end - $qos_stop = $(lvl.ptr)[$(ctx(pos_stop)) + 1] - 1 + $qos_alloc = $(lvl.ptr)[$(ctx(pos_stop)) + 1] - 1 end, ) - lvl.lvl = freeze_level!(ctx, lvl.lvl, value(qos_stop)) + lvl.lvl = freeze_level!(ctx, lvl.lvl, value(qos_alloc)) return lvl end @@ -416,10 +416,10 @@ function thaw_level!(ctx::AbstractCompiler, lvl::VirtualSparseDictLevel, pos_sto push_preamble!( ctx, quote - $(lvl.qos_stop) = 
$(lvl.ptr)[$(ctx(pos_stop)) + 1] - 1 + $(lvl.qos_alloc) = $(lvl.ptr)[$(ctx(pos_stop)) + 1] - 1 end, ) - lvl.lvl = thaw_level!(ctx, lvl.lvl, value(lvl.qos_stop)) + lvl.lvl = thaw_level!(ctx, lvl.lvl, value(lvl.qos_alloc)) return lvl end @@ -528,7 +528,7 @@ function unfurl( tag = lvl.tag Tp = postype(lvl) qos = freshen(ctx, tag, :_qos) - qos_stop = lvl.qos_stop + qos_alloc = lvl.qos_alloc dirty = freshen(ctx, tag, :_dirty) Thunk(; @@ -544,11 +544,11 @@ function unfurl( $qos = pop!($(lvl.pool)) else $qos = length($(lvl.tbl)) + 1 - if $qos > $qos_stop - $qos_stop = max($qos_stop << 1, 1) - $(contain(ctx_2 -> assemble_level!(ctx_2, lvl.lvl, value(qos, Tp), value(qos_stop, Tp)), ctx)) - Finch.resize_if_smaller!($(lvl.val), $qos_stop) - Finch.fill_range!($(lvl.val), 0, $qos, $qos_stop) + if $qos > $qos_alloc + $qos_alloc = max($qos_alloc << 1, 1) + $(contain(ctx_2 -> assemble_level!(ctx_2, lvl.lvl, value(qos, Tp), value(qos_alloc, Tp)), ctx)) + Finch.resize_if_smaller!($(lvl.val), $qos_alloc) + Finch.fill_range!($(lvl.val), 0, $qos, $qos_alloc) end end $(lvl.tbl)[($(ctx(pos)), $(ctx(idx)))] = $qos diff --git a/src/tensors/levels/sparse_interval_levels.jl b/src/tensors/levels/sparse_interval_levels.jl index 6fd825028..bcd0812aa 100644 --- a/src/tensors/levels/sparse_interval_levels.jl +++ b/src/tensors/levels/sparse_interval_levels.jl @@ -185,8 +185,8 @@ mutable struct VirtualSparseIntervalLevel <: AbstractVirtualLevel left right shape - qos_fill - qos_stop + qos_used + qos_alloc prev_pos end @@ -220,11 +220,11 @@ function virtualize( ) shape = value(stop, Int) lvl_2 = virtualize(ctx, :($tag.lvl), Lvl, tag) - qos_fill = freshen(ctx, tag, :_qos_fill) - qos_stop = freshen(ctx, tag, :_qos_stop) + qos_used = freshen(ctx, tag, :_qos_used) + qos_alloc = freshen(ctx, tag, :_qos_alloc) prev_pos = freshen(ctx, tag, :_prev_pos) VirtualSparseIntervalLevel( - tag, lvl_2, Ti, left, right, shape, qos_fill, qos_stop, prev_pos + tag, lvl_2, Ti, left, right, shape, qos_used, qos_alloc, 
prev_pos ) end function lower(ctx::AbstractCompiler, lvl::VirtualSparseIntervalLevel, ::DefaultStyle) @@ -246,8 +246,8 @@ function distribute_level(ctx, lvl::VirtualSparseIntervalLevel, arch, diff, styl distribute_buffer(ctx, lvl.left, arch, style), distribute_buffer(ctx, lvl.right, arch, style), lvl.shape, - lvl.qos_fill, - lvl.qos_stop, + lvl.qos_used, + lvl.qos_alloc, lvl.prev_pos, ) end @@ -263,8 +263,8 @@ function redistribute(ctx::AbstractCompiler, lvl::VirtualSparseIntervalLevel, di lvl.left, lvl.right, lvl.shape, - lvl.qos_fill, - lvl.qos_stop, + lvl.qos_used, + lvl.qos_alloc, lvl.prev_pos, ), ) diff --git a/src/tensors/levels/sparse_list_levels.jl b/src/tensors/levels/sparse_list_levels.jl index e8a6a4c3f..22ce8d824 100644 --- a/src/tensors/levels/sparse_list_levels.jl +++ b/src/tensors/levels/sparse_list_levels.jl @@ -166,8 +166,8 @@ mutable struct VirtualSparseListLevel <: AbstractVirtualLevel ptr idx shape - qos_fill - qos_stop + qos_used + qos_alloc prev_pos end @@ -201,11 +201,11 @@ function virtualize( ) shape = value(stop, Int) lvl_2 = virtualize(ctx, :($tag.lvl), Lvl, tag) - qos_fill = freshen(ctx, tag, :_qos_fill) - qos_stop = freshen(ctx, tag, :_qos_stop) + qos_used = freshen(ctx, tag, :_qos_used) + qos_alloc = freshen(ctx, tag, :_qos_alloc) prev_pos = freshen(ctx, tag, :_prev_pos) VirtualSparseListLevel( - tag, lvl_2, Ti, ptr, idx, shape, qos_fill, qos_stop, prev_pos + tag, lvl_2, Ti, ptr, idx, shape, qos_used, qos_alloc, prev_pos ) end function lower(ctx::AbstractCompiler, lvl::VirtualSparseListLevel, ::DefaultStyle) @@ -229,8 +229,8 @@ function distribute_level( distribute_buffer(ctx, lvl.ptr, arch, style), distribute_buffer(ctx, lvl.idx, arch, style), lvl.shape, - lvl.qos_fill, - lvl.qos_stop, + lvl.qos_used, + lvl.qos_alloc, lvl.prev_pos, ) end @@ -246,8 +246,8 @@ function redistribute(ctx::AbstractCompiler, lvl::VirtualSparseListLevel, diff) lvl.ptr, lvl.idx, lvl.shape, - lvl.qos_fill, - lvl.qos_stop, + lvl.qos_used, + lvl.qos_alloc, 
lvl.prev_pos, ), ) @@ -278,8 +278,8 @@ function declare_level!(ctx::AbstractCompiler, lvl::VirtualSparseListLevel, pos, push_preamble!( ctx, quote - $(lvl.qos_fill) = $(Tp(0)) - $(lvl.qos_stop) = $(Tp(0)) + $(lvl.qos_used) = $(Tp(0)) + $(lvl.qos_alloc) = $(Tp(0)) end, ) if issafe(get_mode_flag(ctx)) @@ -306,7 +306,7 @@ end function freeze_level!(ctx::AbstractCompiler, lvl::VirtualSparseListLevel, pos_stop) p = freshen(ctx, :p) pos_stop = ctx(cache!(ctx, :pos_stop, simplify(ctx, pos_stop))) - qos_stop = freshen(ctx, :qos_stop) + qos_alloc = freshen(ctx, :qos_alloc) push_preamble!( ctx, quote @@ -314,30 +314,30 @@ function freeze_level!(ctx::AbstractCompiler, lvl::VirtualSparseListLevel, pos_s for $p in 1:($pos_stop) $(lvl.ptr)[$p + 1] += $(lvl.ptr)[$p] end - $qos_stop = $(lvl.ptr)[$pos_stop + 1] - 1 - resize!($(lvl.idx), $qos_stop) + $qos_alloc = $(lvl.ptr)[$pos_stop + 1] - 1 + resize!($(lvl.idx), $qos_alloc) end, ) - lvl.lvl = freeze_level!(ctx, lvl.lvl, value(qos_stop)) + lvl.lvl = freeze_level!(ctx, lvl.lvl, value(qos_alloc)) return lvl end function thaw_level!(ctx::AbstractCompiler, lvl::VirtualSparseListLevel, pos_stop) p = freshen(ctx, :p) pos_stop = ctx(cache!(ctx, :pos_stop, simplify(ctx, pos_stop))) - qos_stop = freshen(ctx, :qos_stop) + qos_alloc = freshen(ctx, :qos_alloc) push_preamble!( ctx, quote - $(lvl.qos_fill) = $(lvl.ptr)[$pos_stop + 1] - 1 - $(lvl.qos_stop) = $(lvl.qos_fill) - $qos_stop = $(lvl.qos_fill) + $(lvl.qos_used) = $(lvl.ptr)[$pos_stop + 1] - 1 + $(lvl.qos_alloc) = $(lvl.qos_used) + $qos_alloc = $(lvl.qos_used) $( if issafe(get_mode_flag(ctx)) quote $(lvl.prev_pos) = Finch.scansearch( - $(lvl.ptr), $(lvl.qos_stop) + 1, 1, $pos_stop + $(lvl.ptr), $(lvl.qos_alloc) + 1, 1, $pos_stop ) - 1 end end @@ -347,7 +347,7 @@ function thaw_level!(ctx::AbstractCompiler, lvl::VirtualSparseListLevel, pos_sto end end, ) - lvl.lvl = thaw_level!(ctx, lvl.lvl, value(qos_stop)) + lvl.lvl = thaw_level!(ctx, lvl.lvl, value(qos_alloc)) return lvl end @@ -508,13 
+508,13 @@ function unfurl( tag = lvl.tag Tp = postype(lvl) qos = freshen(ctx, tag, :_qos) - qos_fill = lvl.qos_fill - qos_stop = lvl.qos_stop + qos_used = lvl.qos_used + qos_alloc = lvl.qos_alloc dirty = freshen(ctx, tag, :dirty) Thunk(; preamble = quote - $qos = $qos_fill + 1 + $qos = $qos_used + 1 $(if issafe(get_mode_flag(ctx)) quote $(lvl.prev_pos) < $(ctx(pos)) || throw(FinchProtocolError("SparseListLevels cannot be updated multiple times")) @@ -524,10 +524,10 @@ function unfurl( body = (ctx) -> Lookup(; body=(ctx, idx) -> Thunk(; preamble = quote - if $qos > $qos_stop - $qos_stop = max($qos_stop << 1, 1) - Finch.resize_if_smaller!($(lvl.idx), $qos_stop) - $(contain(ctx_2 -> assemble_level!(ctx_2, lvl.lvl, value(qos, Tp), value(qos_stop, Tp)), ctx)) + if $qos > $qos_alloc + $qos_alloc = max($qos_alloc << 1, 1) + Finch.resize_if_smaller!($(lvl.idx), $qos_alloc) + $(contain(ctx_2 -> assemble_level!(ctx_2, lvl.lvl, value(qos, Tp), value(qos_alloc, Tp)), ctx)) end $dirty = false end, @@ -547,8 +547,8 @@ function unfurl( ) ), epilogue = quote - $(lvl.ptr)[$(ctx(pos)) + 1] += $qos - $qos_fill - 1 - $qos_fill = $qos - 1 + $(lvl.ptr)[$(ctx(pos)) + 1] += $qos - $qos_used - 1 + $qos_used = $qos - 1 end, ) end diff --git a/src/tensors/levels/sparse_rle_levels.jl b/src/tensors/levels/sparse_rle_levels.jl index 8ddc4c425..39e95965c 100644 --- a/src/tensors/levels/sparse_rle_levels.jl +++ b/src/tensors/levels/sparse_rle_levels.jl @@ -222,8 +222,8 @@ mutable struct VirtualSparseRunListLevel <: AbstractVirtualLevel lvl Ti shape - qos_fill - qos_stop + qos_used + qos_alloc ptr left right @@ -250,8 +250,8 @@ function virtualize( ctx, ex, ::Type{SparseRunListLevel{Ti,Ptr,Left,Right,merge,Lvl}}, tag=:lvl ) where {Ti,Ptr,Left,Right,merge,Lvl} tag = freshen(ctx, tag) - qos_fill = freshen(ctx, tag, :_qos_fill) - qos_stop = freshen(ctx, tag, :_qos_stop) + qos_used = freshen(ctx, tag, :_qos_used) + qos_alloc = freshen(ctx, tag, :_qos_alloc) dirty = freshen(ctx, tag, :_dirty) ptr = 
freshen(ctx, tag, :_ptr) left = freshen(ctx, tag, :_left) @@ -274,7 +274,7 @@ function virtualize( lvl_2 = virtualize(ctx, :($tag.lvl), Lvl, tag) buf = virtualize(ctx, :($tag.buf), Lvl, tag) VirtualSparseRunListLevel( - tag, lvl_2, Ti, shape, qos_fill, qos_stop, ptr, left, right, buf, merge, + tag, lvl_2, Ti, shape, qos_used, qos_alloc, ptr, left, right, buf, merge, prev_pos, ) end @@ -300,8 +300,8 @@ function distribute_level( distribute_level(ctx, lvl.lvl, arch, diff, style), lvl.Ti, lvl.shape, - lvl.qos_fill, - lvl.qos_stop, + lvl.qos_used, + lvl.qos_alloc, distribute_buffer(ctx, lvl.ptr, arch, style), distribute_buffer(ctx, lvl.left, arch, style), distribute_buffer(ctx, lvl.right, arch, style), @@ -319,8 +319,8 @@ function redistribute(ctx::AbstractCompiler, lvl::VirtualSparseRunListLevel, dif lvl.tag, redistribute(ctx, lvl.lvl, diff), lvl.Ti, - lvl.qos_fill, - lvl.qos_stop, + lvl.qos_used, + lvl.qos_alloc, lvl.ptr, lvl.left, lvl.right, @@ -355,8 +355,8 @@ function declare_level!(ctx::AbstractCompiler, lvl::VirtualSparseRunListLevel, p push_preamble!( ctx, quote - $(lvl.qos_fill) = $(Tp(0)) - $(lvl.qos_stop) = $(Tp(0)) + $(lvl.qos_used) = $(Tp(0)) + $(lvl.qos_alloc) = $(Tp(0)) end, ) if issafe(get_mode_flag(ctx)) @@ -385,17 +385,17 @@ function freeze_level!(ctx::AbstractCompiler, lvl::VirtualSparseRunListLevel, po (lvl.buf, lvl.lvl) = (lvl.lvl, lvl.buf) p = freshen(ctx, :p) pos_stop = ctx(cache!(ctx, :pos_stop, simplify(ctx, pos_stop))) - qos_stop = freshen(ctx, :qos_stop) + qos_alloc = freshen(ctx, :qos_alloc) push_preamble!(ctx, quote resize!($(lvl.ptr), $pos_stop + 1) for $p = 1:$pos_stop $(lvl.ptr)[$p + 1] += $(lvl.ptr)[$p] end - $qos_stop = $(lvl.ptr)[$pos_stop + 1] - 1 - resize!($(lvl.left), $qos_stop) - resize!($(lvl.right), $qos_stop) + $qos_alloc = $(lvl.ptr)[$pos_stop + 1] - 1 + resize!($(lvl.left), $qos_alloc) + resize!($(lvl.right), $qos_alloc) end) - lvl.lvl = freeze_level!(ctx, lvl.lvl, value(qos_stop)) + lvl.lvl = freeze_level!(ctx, lvl.lvl, 
value(qos_alloc)) return lvl end =# @@ -404,7 +404,7 @@ function freeze_level!(ctx::AbstractCompiler, lvl::VirtualSparseRunListLevel, po Tp = postype(lvl) p = freshen(ctx, :p) pos_stop = ctx(cache!(ctx, :pos_stop, simplify(ctx, pos_stop))) - qos_stop = freshen(ctx, :qos_stop) + qos_alloc = freshen(ctx, :qos_alloc) push_preamble!( ctx, quote @@ -412,11 +412,11 @@ function freeze_level!(ctx::AbstractCompiler, lvl::VirtualSparseRunListLevel, po for $p in 1:($pos_stop) $(lvl.ptr)[$p + 1] += $(lvl.ptr)[$p] end - $qos_stop = $(lvl.ptr)[$pos_stop + 1] - 1 + $qos_alloc = $(lvl.ptr)[$pos_stop + 1] - 1 end, ) if lvl.merge - lvl.buf = freeze_level!(ctx, lvl.buf, value(qos_stop)) + lvl.buf = freeze_level!(ctx, lvl.buf, value(qos_alloc)) lvl.lvl = declare_level!( ctx, lvl.lvl, literal(1), literal(virtual_level_fill_value(lvl.buf)) ) @@ -432,7 +432,7 @@ function freeze_level!(ctx::AbstractCompiler, lvl::VirtualSparseRunListLevel, po quote $(contain( ctx_2 -> - assemble_level!(ctx_2, lvl.lvl, value(1, Tp), value(qos_stop, Tp)), + assemble_level!(ctx_2, lvl.lvl, value(1, Tp), value(qos_alloc, Tp)), ctx, )) $q = 1 @@ -533,10 +533,10 @@ function freeze_level!(ctx::AbstractCompiler, lvl::VirtualSparseRunListLevel, po end resize!($(lvl.left), $q_2 - 1) resize!($(lvl.right), $q_2 - 1) - $qos_stop = $q_2 - 1 + $qos_alloc = $q_2 - 1 end, ) - lvl.lvl = freeze_level!(ctx, lvl.lvl, value(qos_stop)) + lvl.lvl = freeze_level!(ctx, lvl.lvl, value(qos_alloc)) lvl.buf = declare_level!( ctx, lvl.buf, literal(1), literal(virtual_level_fill_value(lvl.buf)) ) @@ -546,12 +546,12 @@ function freeze_level!(ctx::AbstractCompiler, lvl::VirtualSparseRunListLevel, po push_preamble!( ctx, quote - resize!($(lvl.left), $qos_stop) - resize!($(lvl.right), $qos_stop) + resize!($(lvl.left), $qos_alloc) + resize!($(lvl.right), $qos_alloc) end, ) (lvl.lvl, lvl.buf) = (lvl.buf, lvl.lvl) - lvl.lvl = freeze_level!(ctx, lvl.lvl, value(qos_stop)) + lvl.lvl = freeze_level!(ctx, lvl.lvl, value(qos_alloc)) return lvl end 
end @@ -559,19 +559,19 @@ end function thaw_level!(ctx::AbstractCompiler, lvl::VirtualSparseRunListLevel, pos_stop) p = freshen(ctx, :p) pos_stop = ctx(cache!(ctx, :pos_stop, simplify(ctx, pos_stop))) - qos_stop = freshen(ctx, :qos_stop) + qos_alloc = freshen(ctx, :qos_alloc) push_preamble!( ctx, quote - $(lvl.qos_fill) = $(lvl.ptr)[$pos_stop + 1] - 1 - $(lvl.qos_stop) = $(lvl.qos_fill) - $qos_stop = $(lvl.qos_fill) + $(lvl.qos_used) = $(lvl.ptr)[$pos_stop + 1] - 1 + $(lvl.qos_alloc) = $(lvl.qos_used) + $qos_alloc = $(lvl.qos_used) $( if issafe(get_mode_flag(ctx)) quote $(lvl.prev_pos) = Finch.scansearch( - $(lvl.ptr), $(lvl.qos_stop) + 1, 1, $pos_stop + $(lvl.ptr), $(lvl.qos_alloc) + 1, 1, $pos_stop ) - 1 end end @@ -582,7 +582,7 @@ function thaw_level!(ctx::AbstractCompiler, lvl::VirtualSparseRunListLevel, pos_ end, ) (lvl.lvl, lvl.buf) = (lvl.buf, lvl.lvl) - lvl.buf = thaw_level!(ctx, lvl.buf, value(qos_stop)) + lvl.buf = thaw_level!(ctx, lvl.buf, value(qos_alloc)) return lvl end @@ -676,13 +676,13 @@ function unfurl( Tp = postype(lvl) Ti = lvl.Ti qos = freshen(ctx, tag, :_qos) - qos_fill = lvl.qos_fill - qos_stop = lvl.qos_stop + qos_used = lvl.qos_used + qos_alloc = lvl.qos_alloc dirty = freshen(ctx, tag, :dirty) Thunk(; preamble=quote - $qos = $qos_fill + 1 + $qos = $qos_used + 1 $( if issafe(get_mode_flag(ctx)) quote @@ -698,11 +698,11 @@ function unfurl( body=(ctx) -> AcceptRun(; body=(ctx, ext) -> Thunk(; preamble = quote - if $qos > $qos_stop - $qos_stop = max($qos_stop << 1, 1) - Finch.resize_if_smaller!($(lvl.left), $qos_stop) - Finch.resize_if_smaller!($(lvl.right), $qos_stop) - $(contain(ctx_2 -> assemble_level!(ctx_2, lvl.buf, value(qos, Tp), value(qos_stop, Tp)), ctx)) + if $qos > $qos_alloc + $qos_alloc = max($qos_alloc << 1, 1) + Finch.resize_if_smaller!($(lvl.left), $qos_alloc) + Finch.resize_if_smaller!($(lvl.right), $qos_alloc) + $(contain(ctx_2 -> assemble_level!(ctx_2, lvl.buf, value(qos, Tp), value(qos_alloc, Tp)), ctx)) end $dirty = false 
end, @@ -723,8 +723,8 @@ function unfurl( ), ), epilogue=quote - $(lvl.ptr)[$(ctx(pos)) + 1] += $qos - $qos_fill - 1 - $qos_fill = $qos - 1 + $(lvl.ptr)[$(ctx(pos)) + 1] += $qos - $qos_used - 1 + $qos_used = $qos - 1 end, ) end diff --git a/src/tensors/levels/sparse_vbl_levels.jl b/src/tensors/levels/sparse_vbl_levels.jl index 570c92ff8..3fd123db2 100644 --- a/src/tensors/levels/sparse_vbl_levels.jl +++ b/src/tensors/levels/sparse_vbl_levels.jl @@ -195,8 +195,8 @@ mutable struct VirtualSparseBlockListLevel <: AbstractVirtualLevel lvl Ti shape - qos_fill - qos_stop + qos_used + qos_alloc ros_fill ros_stop dirty @@ -223,8 +223,8 @@ function virtualize( ctx, ex, ::Type{SparseBlockListLevel{Ti,Ptr,Idx,Ofs,Lvl}}, tag=:lvl ) where {Ti,Ptr,Idx,Ofs,Lvl} tag = freshen(ctx, tag) - qos_fill = freshen(ctx, tag, :_qos_fill) - qos_stop = freshen(ctx, tag, :_qos_stop) + qos_used = freshen(ctx, tag, :_qos_used) + qos_alloc = freshen(ctx, tag, :_qos_alloc) ros_fill = freshen(ctx, tag, :_ros_fill) ros_stop = freshen(ctx, tag, :_ros_stop) dirty = freshen(ctx, tag, :_dirty) @@ -250,8 +250,8 @@ function virtualize( lvl_2, Ti, shape, - qos_fill, - qos_stop, + qos_used, + qos_alloc, ros_fill, ros_stop, dirty, @@ -281,8 +281,8 @@ function distribute_level( distribute_level(ctx, lvl.lvl, arch, diff, style), lvl.Ti, lvl.shape, - lvl.qos_fill, - lvl.qos_stop, + lvl.qos_used, + lvl.qos_alloc, lvl.ros_fill, lvl.ros_stop, lvl.dirty, @@ -301,8 +301,8 @@ function redistribute(ctx::AbstractCompiler, lvl::VirtualSparseBlockListLevel, d lvl.tag, redistribute(ctx, lvl.lvl, diff), lvl.Ti, - lvl.qos_fill, - lvl.qos_stop, + lvl.qos_used, + lvl.qos_alloc, lvl.ros_fill, lvl.ros_stop, lvl.dirty, @@ -338,8 +338,8 @@ function declare_level!(ctx::AbstractCompiler, lvl::VirtualSparseBlockListLevel, push_preamble!( ctx, quote - $(lvl.qos_fill) = $(Tp(0)) - $(lvl.qos_stop) = $(Tp(0)) + $(lvl.qos_used) = $(Tp(0)) + $(lvl.qos_alloc) = $(Tp(0)) $(lvl.ros_fill) = $(Tp(0)) $(lvl.ros_stop) = $(Tp(0)) 
Finch.resize_if_smaller!($(lvl.ofs), 1) @@ -372,7 +372,7 @@ function freeze_level!(ctx::AbstractCompiler, lvl::VirtualSparseBlockListLevel, Tp = postype(lvl) pos_stop = ctx(cache!(ctx, :pos_stop, simplify(ctx, pos_stop))) ros_stop = freshen(ctx, :ros_stop) - qos_stop = freshen(ctx, :qos_stop) + qos_alloc = freshen(ctx, :qos_alloc) push_preamble!( ctx, quote @@ -383,10 +383,10 @@ function freeze_level!(ctx::AbstractCompiler, lvl::VirtualSparseBlockListLevel, $ros_stop = $(lvl.ptr)[$pos_stop + 1] - 1 resize!($(lvl.idx), $ros_stop) resize!($(lvl.ofs), $ros_stop + 1) - $qos_stop = $(lvl.ofs)[$ros_stop + 1] - $(Tp(1)) + $qos_alloc = $(lvl.ofs)[$ros_stop + 1] - $(Tp(1)) end, ) - lvl.lvl = freeze_level!(ctx, lvl.lvl, value(qos_stop)) + lvl.lvl = freeze_level!(ctx, lvl.lvl, value(qos_alloc)) return lvl end @@ -562,8 +562,8 @@ function unfurl( my_i_prev = freshen(ctx, tag, :_i_prev) qos = freshen(ctx, tag, :_qos) ros = freshen(ctx, tag, :_ros) - qos_fill = lvl.qos_fill - qos_stop = lvl.qos_stop + qos_used = lvl.qos_used + qos_alloc = lvl.qos_alloc ros_fill = lvl.ros_fill ros_stop = lvl.ros_stop dirty = freshen(ctx, tag, :dirty) @@ -571,7 +571,7 @@ function unfurl( Thunk(; preamble = quote $ros = $ros_fill - $qos = $qos_fill + 1 + $qos = $qos_used + 1 $my_i_prev = $(Ti(-1)) $(if issafe(get_mode_flag(ctx)) quote @@ -582,9 +582,9 @@ function unfurl( body = (ctx) -> Lookup(; body=(ctx, idx) -> Thunk(; preamble = quote - if $qos > $qos_stop - $qos_stop = max($qos_stop << 1, 1) - $(contain(ctx_2 -> assemble_level!(ctx_2, lvl.lvl, value(qos, Tp), value(qos_stop, Tp)), ctx)) + if $qos > $qos_alloc + $qos_alloc = max($qos_alloc << 1, 1) + $(contain(ctx_2 -> assemble_level!(ctx_2, lvl.lvl, value(qos, Tp), value(qos_alloc, Tp)), ctx)) end $dirty = false end, @@ -615,7 +615,7 @@ function unfurl( epilogue = quote $(lvl.ptr)[$(ctx(pos)) + 1] = $ros - $ros_fill $ros_fill = $ros - $qos_fill = $qos - 1 + $qos_used = $qos - 1 end, ) end From 2b390447a3dfdc87a12b42ba01dcc243afd4931f Mon Sep 
17 00:00:00 2001 From: Willow Ahrens Date: Wed, 9 Apr 2025 19:45:41 -0400 Subject: [PATCH 2/2] initial plan --- src/scheduler/optimize.jl | 95 ++++++++++++++++++++++----------------- 1 file changed, 55 insertions(+), 40 deletions(-) diff --git a/src/scheduler/optimize.jl b/src/scheduler/optimize.jl index dd7db4860..1e81cea59 100644 --- a/src/scheduler/optimize.jl +++ b/src/scheduler/optimize.jl @@ -339,6 +339,11 @@ function propagate_into_reformats(root) ) end +""" + issubsequence(a, b) + +Returns true if `a` is a subsequence of `b`. +""" function issubsequence(a, b) a = collect(a) b = collect(b) @@ -616,52 +621,62 @@ function normalize_names(ex) Rewrite(Postwalk(@rule ~a::isalias => alias(normname(a.name))))(ex) end -function toposort(chains::Vector{Vector{T}}) where {T} - chains = filter(!isempty, deepcopy(chains)) - parents = Dict{T,Int}(map(chain -> first(chain) => 0, chains)) - for chain in chains, u in chain[2:end] - parents[u] += 1 - end - roots = filter(u -> parents[u] == 0, keys(parents)) - perm = [] - while !isempty(parents) - isempty(roots) && return nothing - push!(perm, pop!(roots)) - for chain in chains - if !isempty(chain) && first(chain) == last(perm) - popfirst!(chain) - if !isempty(chain) - parents[first(chain)] -= 1 - if parents[first(chain)] == 0 - push!(roots, first(chain)) - end - end - end - end - pop!(parents, last(perm)) - end - return perm -end +""" + heuristic_loop_order(node, reps, reduced = []) -function heuristic_loop_order(node, reps) - chains = Vector{LogicNode}[] +Heuristically determine a loop order for a reduction. +TODO could be made way better by considering sparsity +""" +function heuristic_loop_order(node, reps, reduced = []) + swizzles = Vector{LogicNode}[] for node in PostOrderDFS(node) if @capture node reorder(relabel(~arg, ~idxs...), ~idxs_2...) 
- push!(chains, intersect(idxs, idxs_2)) + push!(swizzles, intersect(idxs, idxs_2)) end end - for idx in getfields(node) - push!(chains, [idx]) - end - res = something(toposort(chains), getfields(node)) - if mapreduce(length, max, chains; init=0) < length(unique(reduce(vcat, chains))) - counts = Dict() - for chain in chains - for idx in chain - counts[idx] = get(counts, idx, 0) + 1 - end + + #The maximum number of dimensions of any tensor + max_num_dims = mapreduce(length, max, swizzles; init=0) + #The number of loops over at least one of the arguments + num_compute_loops = length(unique(reduce(vcat, swizzles))) + #The number of times each index occurs in an access + idx_counts = countmap(reduce(vcat, swizzles)) + #At least one access containing each index + idx_swizzles = Dict([idx => swizzle for idx in swizzle for swizzle in swizzles]) + best_score = nothing + res = nothing + for order in permutations(getfields(node)) + #A list of the transposed accesses + transposed = filter(swizzle -> !issubsequence(swizzle, order), swizzles) + #The number of dimensions in the largest transpose + biggest_transpose = mapreduce(length, max, transposed; init=0) + #The number of transposes which are as big as the biggest transpose + num_big_transposes = count(swizzle -> length(swizzle) == biggest_transpose, transposed) + #If the number of indices in the loop nest is bigger than the number of + #indices on any one argument, transposition should be asymptotically + #dominated by the cost of the whole expression. In this case, we can change + #the loop order to improve performance. + if num_compute_loops > max_num_dims + #The loop depth of the last index shared between more than one access + last_shared_idx_depth = something(findlast(idx -> idx_counts[idx] > 1, order), 0) + #We say an outer product (two indices over unrelated tensors) is + #unguarded if a shared index is nested within it. Ideally, we would + #guard the outer product with the shared index. 
+ unguarded_outer_product = length(unique([idx_swizzles[idx] for idx in order[1:last_shared_idx_depth] if idx_counts[idx] == 1])) > 1 + #Because the output index order is set by the loop order, we can + #only scatter during reductions. A scattered dimension is an unreduced dimension + #nested inside a reduced dimension. + num_scattered_dims = count(idx -> !in(idx, reduced), order) - something(findfirst(idx -> in(idx, reduced), order), length(order) + 1) - 1 + #When there is one scattered dimension, we can use Gustavson's algorithm. + nontrivial_scatter = num_scattered_dims > 1 + score = (unguarded_outer_product, nontrivial_scatter, biggest_transpose, num_big_transposes) + else + score = (biggest_transpose, num_big_transposes) + end + if best_score === nothing || score < best_score + best_score = score + res = order end - sort!(res; by=idx -> counts[idx] == 1, alg=Base.MergeSort) end return res end