Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 52 additions & 52 deletions ext/SparseArraysExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,8 @@ end
ptr
idx
val
qos_fill
qos_stop
qos_used
qos_alloc
prev_pos
end

Expand Down Expand Up @@ -127,14 +127,14 @@ function Finch.virtualize(ctx, ex, ::Type{<:SparseMatrixCSC{Tv,Ti}}, tag=:tns) w
$val = $tag.nzval
end,
)
qos_fill = freshen(ctx, tag, :_qos_fill)
qos_stop = freshen(ctx, tag, :_qos_stop)
qos_used = freshen(ctx, tag, :_qos_used)
qos_alloc = freshen(ctx, tag, :_qos_alloc)
prev_pos = freshen(ctx, tag, :_prev_pos)
shape = [
VirtualExtent(literal(1), value(m, Ti)), VirtualExtent(literal(1), value(n, Ti))
]
VirtualSparseMatrixCSC(
tag, Tv, Ti, shape, ptr, idx, val, qos_fill, qos_stop, prev_pos
tag, Tv, Ti, shape, ptr, idx, val, qos_used, qos_alloc, prev_pos
)
end

Expand All @@ -149,8 +149,8 @@ function distribute(
distribute_buffer(ctx, arr.ptr, arch, style),
distribute_buffer(ctx, arr.idx, arch, style),
distribute_buffer(ctx, arr.val, arch, style),
arr.qos_fill,
arr.qos_stop,
arr.qos_used,
arr.qos_alloc,
arr.prev_pos,
)
end
Expand All @@ -166,8 +166,8 @@ function Finch.declare!(ctx::AbstractCompiler, arr::VirtualSparseMatrixCSC, init
push_preamble!(
ctx,
quote
$(arr.qos_fill) = $(Tp(0))
$(arr.qos_stop) = $(Tp(0))
$(arr.qos_used) = $(Tp(0))
$(arr.qos_alloc) = $(Tp(0))
resize!($(arr.ptr), $pos_stop + 1)
fill_range!($(arr.ptr), $(Tp(0)), 1, $pos_stop + 1)
$(arr.ptr)[1] = $(Tp(1))
Expand All @@ -187,17 +187,17 @@ end
function Finch.freeze!(ctx::AbstractCompiler, arr::VirtualSparseMatrixCSC)
p = freshen(ctx, :p)
pos_stop = ctx(getstop(virtual_size(ctx, arr)[2]))
qos_stop = freshen(ctx, :qos_stop)
qos_alloc = freshen(ctx, :qos_alloc)
push_preamble!(
ctx,
quote
resize!($(arr.ptr), $pos_stop + 1)
for $p in 1:($pos_stop)
$(arr.ptr)[$p + 1] += $(arr.ptr)[$p]
end
$qos_stop = $(arr.ptr)[$pos_stop + 1] - 1
resize!($(arr.idx), $qos_stop)
resize!($(arr.val), $qos_stop)
$qos_alloc = $(arr.ptr)[$pos_stop + 1] - 1
resize!($(arr.idx), $qos_alloc)
resize!($(arr.val), $qos_alloc)
end,
)
return arr
Expand All @@ -206,19 +206,19 @@ end
function Finch.thaw!(ctx::AbstractCompiler, arr::VirtualSparseMatrixCSC)
p = freshen(ctx, :p)
pos_stop = ctx(getstop(virtual_size(ctx, arr)[2]))
qos_stop = freshen(ctx, :qos_stop)
qos_alloc = freshen(ctx, :qos_alloc)
push_preamble!(
ctx,
quote
$(arr.qos_fill) = $(arr.ptr)[$pos_stop + 1] - 1
$(arr.qos_stop) = $(arr.qos_fill)
$qos_stop = $(arr.qos_fill)
$(arr.qos_used) = $(arr.ptr)[$pos_stop + 1] - 1
$(arr.qos_alloc) = $(arr.qos_used)
$qos_alloc = $(arr.qos_used)
$(
if issafe(get_mode_flag(ctx))
quote
$(arr.prev_pos) =
Finch.scansearch(
$(arr.ptr), $(arr.qos_stop) + 1, 1, $pos_stop
$(arr.ptr), $(arr.qos_alloc) + 1, 1, $pos_stop
) - 1
end
end
Expand Down Expand Up @@ -350,12 +350,12 @@ function Finch.unfurl(
j = tns.j
Tp = arr.Ti
qos = freshen(ctx, tag, :_qos)
qos_fill = arr.qos_fill
qos_stop = arr.qos_stop
qos_used = arr.qos_used
qos_alloc = arr.qos_alloc
dirty = freshen(ctx, tag, :dirty)
Thunk(;
preamble = quote
$qos = $qos_fill + 1
$qos = $qos_used + 1
$(if issafe(get_mode_flag(ctx))
quote
$(arr.prev_pos) < $(ctx(j)) || throw(FinchProtocolError("SparseMatrixCSCs cannot be updated multiple times"))
Expand All @@ -365,10 +365,10 @@ function Finch.unfurl(
body = (ctx) -> Lookup(;
body=(ctx, idx) -> Thunk(;
preamble = quote
if $qos > $qos_stop
$qos_stop = max($qos_stop << 1, 1)
Finch.resize_if_smaller!($(arr.idx), $qos_stop)
Finch.resize_if_smaller!($(arr.val), $qos_stop)
if $qos > $qos_alloc
$qos_alloc = max($qos_alloc << 1, 1)
Finch.resize_if_smaller!($(arr.idx), $qos_alloc)
Finch.resize_if_smaller!($(arr.val), $qos_alloc)
end
$dirty = false
end,
Expand All @@ -387,8 +387,8 @@ function Finch.unfurl(
)
),
epilogue = quote
$(arr.ptr)[$(ctx(j)) + 1] += $qos - $qos_fill - 1
$qos_fill = $qos - 1
$(arr.ptr)[$(ctx(j)) + 1] += $qos - $qos_used - 1
$qos_used = $qos - 1
end,
)
end
Expand Down Expand Up @@ -429,8 +429,8 @@ end
shape
idx
val
qos_fill
qos_stop
qos_used
qos_alloc
end

function Finch.virtual_size(ctx::AbstractCompiler, arr::VirtualSparseVector)
Expand Down Expand Up @@ -460,9 +460,9 @@ function Finch.virtualize(ctx, ex, ::Type{<:SparseVector{Tv,Ti}}, tag=:tns) wher
$val = $tag.nzval
end,
)
qos_fill = freshen(ctx, tag, :_qos_fill)
qos_stop = freshen(ctx, tag, :_qos_stop)
VirtualSparseVector(tag, Tv, Ti, shape, idx, val, qos_fill, qos_stop)
qos_used = freshen(ctx, tag, :_qos_used)
qos_alloc = freshen(ctx, tag, :_qos_alloc)
VirtualSparseVector(tag, Tv, Ti, shape, idx, val, qos_used, qos_alloc)
end

function distribute(
Expand All @@ -475,8 +475,8 @@ function distribute(
arr.shape,
distribute_buffer(ctx, arr.idx, arch, style),
distribute_buffer(ctx, arr.val, arch, style),
arr.qos_fill,
arr.qos_stop,
arr.qos_used,
arr.qos_alloc,
)
end

Expand All @@ -490,36 +490,36 @@ function Finch.declare!(ctx::AbstractCompiler, arr::VirtualSparseVector, init)
push_preamble!(
ctx,
quote
$(arr.qos_fill) = $(Tp(0))
$(arr.qos_stop) = $(Tp(0))
$(arr.qos_used) = $(Tp(0))
$(arr.qos_alloc) = $(Tp(0))
end,
)
return arr
end

function Finch.freeze!(ctx::AbstractCompiler, arr::VirtualSparseVector)
p = freshen(ctx, :p)
qos_stop = freshen(ctx, :qos_stop)
qos_alloc = freshen(ctx, :qos_alloc)
push_preamble!(
ctx,
quote
$qos_stop = $(ctx(arr.qos_fill))
resize!($(arr.idx), $qos_stop)
resize!($(arr.val), $qos_stop)
$qos_alloc = $(ctx(arr.qos_used))
resize!($(arr.idx), $qos_alloc)
resize!($(arr.val), $qos_alloc)
end,
)
return arr
end

function Finch.thaw!(ctx::AbstractCompiler, arr::VirtualSparseVector)
p = freshen(ctx, :p)
qos_stop = freshen(ctx, :qos_stop)
qos_alloc = freshen(ctx, :qos_alloc)
push_preamble!(
ctx,
quote
$(arr.qos_fill) = length($(arr.idx))
$(arr.qos_stop) = $(arr.qos_fill)
$qos_stop = $(arr.qos_fill)
$(arr.qos_used) = length($(arr.idx))
$(arr.qos_alloc) = $(arr.qos_used)
$qos_alloc = $(arr.qos_used)
end,
)
return arr
Expand Down Expand Up @@ -593,23 +593,23 @@ function Finch.unfurl(
tag = arr.tag
Tp = arr.Ti
qos = freshen(ctx, tag, :_qos)
qos_fill = arr.qos_fill
qos_stop = arr.qos_stop
qos_used = arr.qos_used
qos_alloc = arr.qos_alloc
dirty = freshen(ctx, tag, :dirty)

Unfurled(;
arr=arr,
body=Thunk(;
preamble = quote
$qos = $qos_fill + 1
$qos = $qos_used + 1
end,
body = (ctx) -> Lookup(;
body=(ctx, idx) -> Thunk(;
preamble = quote
if $qos > $qos_stop
$qos_stop = max($qos_stop << 1, 1)
Finch.resize_if_smaller!($(arr.idx), $qos_stop)
Finch.resize_if_smaller!($(arr.val), $qos_stop)
if $qos > $qos_alloc
$qos_alloc = max($qos_alloc << 1, 1)
Finch.resize_if_smaller!($(arr.idx), $qos_alloc)
Finch.resize_if_smaller!($(arr.val), $qos_alloc)
end
$dirty = false
end,
Expand All @@ -623,7 +623,7 @@ function Finch.unfurl(
)
),
epilogue = quote
$qos_fill = $qos - 1
$qos_used = $qos - 1
end,
),
)
Expand Down
2 changes: 2 additions & 0 deletions src/Finch.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ export Dense, DenseLevel
export Element, ElementLevel
export AtomicElement, AtomicElementLevel
export Separate, SeparateLevel
export Shard, ShardLevel
export Mutex, MutexLevel
export Pattern, PatternLevel
export Scalar, SparseScalar, ShortCircuitScalar, SparseShortCircuitScalar
Expand Down Expand Up @@ -142,6 +143,7 @@ include("tensors/levels/dense_rle_levels.jl")
include("tensors/levels/element_levels.jl")
include("tensors/levels/atomic_element_levels.jl")
include("tensors/levels/separate_levels.jl")
include("tensors/levels/shard_levels.jl")
include("tensors/levels/mutex_levels.jl")
include("tensors/levels/pattern_levels.jl")
include("tensors/masks.jl")
Expand Down
61 changes: 50 additions & 11 deletions src/architecture.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,14 @@ abstract type AbstractVirtualTask end
Return the number of tasks on the device dev.
"""
function get_num_tasks end

"""
get_task_num(task::AbstractTask)

Return the task number of `task`.
"""
function get_task_num end

"""
get_device(task::AbstractTask)

Expand All @@ -61,6 +63,25 @@ Return the task which spawned `task`.
"""
function get_parent_task end

# Convenience methods on a compiler context: each query is forwarded to the
# task obtained from `get_task(ctx)`. A task, in turn, reports the task count
# of the device returned by `get_device(task)`.
function get_num_tasks(ctx::AbstractCompiler)
    return get_num_tasks(get_task(ctx))
end

function get_num_tasks(task::AbstractTask)
    return get_num_tasks(get_device(task))
end

function get_task_num(ctx::AbstractCompiler)
    return get_task_num(get_task(ctx))
end

function get_device(ctx::AbstractCompiler)
    return get_device(get_task(ctx))
end

function get_parent_task(ctx::AbstractCompiler)
    return get_parent_task(get_task(ctx))
end

"""
    is_on_device(ctx::AbstractCompiler, dev)

Return `true` if the task of `ctx`, or any task reached by repeatedly following
`get_parent_task`, has a device that compares equal (`==`) to `dev`; return
`false` otherwise.

The walk starts at `get_task(ctx)` and stops when `get_parent_task` yields
`nothing`.
"""
function is_on_device(ctx::AbstractCompiler, dev)
    task = get_task(ctx)
    # Compare against `nothing` by identity: `!=` would dispatch to whatever
    # `==` method the task type defines, whereas `!==` cannot be overloaded.
    while task !== nothing
        get_device(task) == dev && return true
        task = get_parent_task(task)
    end
    return false
end

"""
aquire_lock!(dev::AbstractDevice, val)

Expand Down Expand Up @@ -92,20 +113,35 @@ function make_lock end
"""
Serial()

A device that represents a serial CPU execution.
A Task that represents a serial CPU execution.
"""
struct Serial <: AbstractTask end
struct Serial <: AbstractDevice end
const serial = Serial()
get_device(::Serial) = CPU(1)
get_parent_task(::Serial) = nothing
get_task_num(::Serial) = 1
get_num_tasks(::Serial) = 1
struct VirtualSerial <: AbstractVirtualTask end
virtualize(ctx, ex, ::Type{Serial}) = VirtualSerial()
lower(ctx::AbstractCompiler, task::VirtualSerial, ::DefaultStyle) = :(Serial())
FinchNotation.finch_leaf(device::VirtualSerial) = virtual(device)
get_device(::VirtualSerial) = VirtualCPU(nothing, 1)
get_parent_task(::VirtualSerial) = nothing
get_task_num(::VirtualSerial) = literal(1)
get_num_tasks(::VirtualSerial) = literal(1)
Base.:(==)(::Serial, ::Serial) = true
Base.:(==)(::VirtualSerial, ::VirtualSerial) = true

"""
SerialTask()

A Task that represents a serial CPU execution.
"""
struct SerialTask <: AbstractDevice end
# NOTE(review): the docstring above calls this a Task and its virtual
# counterpart below subtypes `AbstractVirtualTask`, yet the supertype here is
# `AbstractDevice` — confirm whether `AbstractTask` was intended.

# Task queries for `SerialTask`: its device is `Serial()`, it has no parent
# task, and its task number is 1.
get_device(::SerialTask) = Serial()
get_parent_task(::SerialTask) = nothing
get_task_num(::SerialTask) = 1

# Virtual counterpart of `SerialTask`.
struct VirtualSerialTask <: AbstractVirtualTask end
# Virtualizing the `SerialTask` type yields a `VirtualSerialTask`; `ex` is unused.
virtualize(ctx, ex, ::Type{SerialTask}) = VirtualSerialTask()
# Lowering produces the expression `SerialTask()`.
lower(ctx::AbstractCompiler, task::VirtualSerialTask, ::DefaultStyle) = :(SerialTask())
# Wrap the virtual task with `virtual(...)` when it is used as a leaf.
FinchNotation.finch_leaf(device::VirtualSerialTask) = virtual(device)
# Same queries as above, answered with virtual values: the device is a
# `VirtualSerial`, there is no parent, and the task number is `literal(1)`.
get_device(::VirtualSerialTask) = VirtualSerial()
get_parent_task(::VirtualSerialTask) = nothing
get_task_num(::VirtualSerialTask) = literal(1)

struct SerialMemory end
struct VirtualSerialMemory end
Expand Down Expand Up @@ -148,6 +184,8 @@ function lower(ctx::AbstractCompiler, device::VirtualCPU, ::DefaultStyle)
something(device.ex, :(CPU($(ctx(device.n)))))
end
get_num_tasks(::VirtualCPU) = literal(1)
Base.:(==)(::CPU, ::CPU) = true
Base.:(==)(::VirtualCPU, ::VirtualCPU) = true #This is not strictly true. A better approach would name devices, and give them parents so that we can be sure to parallelize through the processor hierarchy.

FinchNotation.finch_leaf(device::VirtualCPU) = virtual(device)

Expand Down Expand Up @@ -212,7 +250,7 @@ function transfer(device::CPULocalMemory, arr::AbstractArray)
CPULocalArray{A}(mem.device, [copy(arr) for _ in 1:(mem.device.n)])
end
function transfer(task::CPUThread, arr::CPULocalArray)
if get_device(task) === arr.device
if get_device(task) == arr.device
temp = arr.data[task.tid]
return temp
else
Expand All @@ -223,6 +261,7 @@ function transfer(dst::AbstractArray, arr::AbstractArray)
return arr
end


"""
transfer(device, arr)

Expand Down Expand Up @@ -484,8 +523,8 @@ for T in [
end
end

function virtual_parallel_region(f, ctx, ::Serial)
contain(f, ctx)
function virtual_parallel_region(f, ctx, ::VirtualSerial)
contain(f, ctx; task=VirtualSerialTask())
end

function virtual_parallel_region(f, ctx, device::VirtualCPU)
Expand Down
Loading
Loading