Skip to content

Commit bc79b84

Browse files
committed
fixup! fixup! fixup! Add streaming API
1 parent b22567c commit bc79b84

File tree

7 files changed

+392
-153
lines changed

7 files changed

+392
-153
lines changed

Manifest.toml

Lines changed: 25 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
julia_version = "1.8.5"
44
manifest_format = "2.0"
5-
project_hash = "5333a6c200b6e6add81c46547527f66ddc0dc16c"
5+
project_hash = "1e12d6aa088ae431916872c11d09544380c7a130"
66

77
[[deps.Artifacts]]
88
uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33"
@@ -12,9 +12,9 @@ uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"
1212

1313
[[deps.ChainRulesCore]]
1414
deps = ["Compat", "LinearAlgebra", "SparseArrays"]
15-
git-tree-sha1 = "b66b8f8e3db5d7835fb8cbe2589ffd1cd456e491"
15+
git-tree-sha1 = "575cd02e080939a33b6df6c5853d14924c08e35b"
1616
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
17-
version = "1.17.0"
17+
version = "1.23.0"
1818

1919
[[deps.ChangesOfVariables]]
2020
deps = ["InverseFunctions", "LinearAlgebra", "Test"]
@@ -23,26 +23,26 @@ uuid = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0"
2323
version = "0.1.8"
2424

2525
[[deps.Compat]]
26-
deps = ["Dates", "LinearAlgebra", "UUIDs"]
27-
git-tree-sha1 = "8a62af3e248a8c4bad6b32cbbe663ae02275e32c"
26+
deps = ["Dates", "LinearAlgebra", "TOML", "UUIDs"]
27+
git-tree-sha1 = "c955881e3c981181362ae4088b35995446298b80"
2828
uuid = "34da2185-b29b-5c13-b0c7-acf172513d20"
29-
version = "4.10.0"
29+
version = "4.14.0"
3030

3131
[[deps.CompilerSupportLibraries_jll]]
3232
deps = ["Artifacts", "Libdl"]
3333
uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae"
3434
version = "1.0.1+0"
3535

3636
[[deps.DataAPI]]
37-
git-tree-sha1 = "8da84edb865b0b5b0100c0666a9bc9a0b71c553c"
37+
git-tree-sha1 = "abe83f3a2f1b857aac70ef8b269080af17764bbe"
3838
uuid = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a"
39-
version = "1.15.0"
39+
version = "1.16.0"
4040

4141
[[deps.DataStructures]]
4242
deps = ["Compat", "InteractiveUtils", "OrderedCollections"]
43-
git-tree-sha1 = "3dbd312d370723b6bb43ba9d02fc36abade4518d"
43+
git-tree-sha1 = "0f4b5d62a88d8f59003e43c25a8a90de9eb76317"
4444
uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
45-
version = "0.18.15"
45+
version = "0.18.18"
4646

4747
[[deps.Dates]]
4848
deps = ["Printf"]
@@ -91,28 +91,28 @@ uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
9191

9292
[[deps.LogExpFunctions]]
9393
deps = ["ChainRulesCore", "ChangesOfVariables", "DocStringExtensions", "InverseFunctions", "IrrationalConstants", "LinearAlgebra"]
94-
git-tree-sha1 = "7d6dd4e9212aebaeed356de34ccf262a3cd415aa"
94+
git-tree-sha1 = "18144f3e9cbe9b15b070288eef858f71b291ce37"
9595
uuid = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
96-
version = "0.3.26"
96+
version = "0.3.27"
9797

9898
[[deps.Logging]]
9999
uuid = "56ddb016-857b-54e1-b83d-db4d58db5568"
100100

101101
[[deps.MacroTools]]
102102
deps = ["Markdown", "Random"]
103-
git-tree-sha1 = "9ee1618cbf5240e6d4e0371d6f24065083f60c48"
103+
git-tree-sha1 = "2fa9ee3e63fd3a4f7a9a4f4744a52f4856de82df"
104104
uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
105-
version = "0.5.11"
105+
version = "0.5.13"
106106

107107
[[deps.Markdown]]
108108
deps = ["Base64"]
109109
uuid = "d6f4376e-aef5-505a-96c1-9c027394607a"
110110

111111
[[deps.MemPool]]
112-
deps = ["DataStructures", "Distributed", "Mmap", "Random", "Serialization", "Sockets"]
113-
git-tree-sha1 = "b9c1a032c3c1310a857c061ce487c632eaa1faa4"
112+
deps = ["DataStructures", "Distributed", "Mmap", "Random", "ScopedValues", "Serialization", "Sockets"]
113+
git-tree-sha1 = "60dd4ac427d39e0b3f15b193845324523ee71c03"
114114
uuid = "f9f48841-c794-520a-933b-121f7ba6ed94"
115-
version = "0.4.4"
115+
version = "0.4.6"
116116

117117
[[deps.Missings]]
118118
deps = ["DataAPI"]
@@ -133,9 +133,9 @@ uuid = "4536629a-c528-5b80-bd46-f80d51c5b363"
133133
version = "0.3.20+0"
134134

135135
[[deps.OrderedCollections]]
136-
git-tree-sha1 = "2e73fe17cac3c62ad1aebe70d44c963c3cfdc3e3"
136+
git-tree-sha1 = "dfdf5519f235516220579f949664f1bf44e741c5"
137137
uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
138-
version = "1.6.2"
138+
version = "1.6.3"
139139

140140
[[deps.PrecompileTools]]
141141
deps = ["Preferences"]
@@ -145,9 +145,9 @@ version = "1.2.0"
145145

146146
[[deps.Preferences]]
147147
deps = ["TOML"]
148-
git-tree-sha1 = "00805cd429dcb4870060ff49ef443486c262e38e"
148+
git-tree-sha1 = "9306f6085165d270f7e3db02af26a400d580f5c6"
149149
uuid = "21216c6a-2e73-6563-6e65-726566657250"
150-
version = "1.4.1"
150+
version = "1.4.3"
151151

152152
[[deps.Printf]]
153153
deps = ["Unicode"]
@@ -173,9 +173,9 @@ version = "0.7.0"
173173

174174
[[deps.ScopedValues]]
175175
deps = ["HashArrayMappedTries", "Logging"]
176-
git-tree-sha1 = "e3b5e4ccb1702db2ae9ac2a660d4b6b2a8595742"
176+
git-tree-sha1 = "c27d546a4749c81f70d1fabd604da6aa5054e3d2"
177177
uuid = "7e506255-f358-4e82-b7e4-beb19740aa63"
178-
version = "1.1.0"
178+
version = "1.2.0"
179179

180180
[[deps.Serialization]]
181181
uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
@@ -189,9 +189,9 @@ uuid = "6462fe0b-24de-5631-8697-dd941f90decc"
189189

190190
[[deps.SortingAlgorithms]]
191191
deps = ["DataStructures"]
192-
git-tree-sha1 = "c60ec5c62180f27efea3ba2908480f8055e17cee"
192+
git-tree-sha1 = "66e0a8e672a0bdfca2c3f5937efb8538b9ddc085"
193193
uuid = "a2af1166-a08f-5f64-846c-94a0d3cef48c"
194-
version = "1.1.1"
194+
version = "1.2.1"
195195

196196
[[deps.SparseArrays]]
197197
deps = ["LinearAlgebra", "Random"]

Project.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@ Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
88
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
99
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
1010
MemPool = "f9f48841-c794-520a-933b-121f7ba6ed94"
11-
Mmap = "a63ad114-7e13-5084-954f-fe012c677804"
1211
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
1312
Profile = "9abbd945-dff8-562f-b5e8-e1ebf5ef1b79"
1413
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"

src/Dagger.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,8 @@ include("utils/caching.jl")
4343
include("sch/Sch.jl"); using .Sch
4444

4545
# Streaming
46+
include("stream-buffers.jl")
47+
include("stream-fetchers.jl")
4648
include("stream.jl")
4749

4850
# Array computations

src/eager_thunk.jl

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -67,17 +67,7 @@ function Base.fetch(t::EagerThunk; raw=false)
6767
if !isdefined(t, :thunk_ref)
6868
throw(ConcurrencyViolationError("Cannot `fetch` an unlaunched `EagerThunk`"))
6969
end
70-
stream = task_to_stream(t.uid)
71-
if stream isa Stream
72-
add_waiters!(stream, [0])
73-
end
74-
try
75-
return fetch(t.future; raw)
76-
finally
77-
if stream isa Stream
78-
remove_waiters!(stream, [0])
79-
end
80-
end
70+
return fetch(t.future; raw)
8171
end
8272
function Base.show(io::IO, t::EagerThunk)
8373
status = if isdefined(t, :thunk_ref)

src/stream-buffers.jl

Lines changed: 204 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,204 @@
1+
using Mmap
2+
3+
"""
4+
A buffer that drops all elements put into it. Only to be used as the output
5+
buffer for a task - will throw if attached as an input.
6+
"""
7+
struct DropBuffer{T} end
8+
DropBuffer{T}(_) where T = DropBuffer{T}()
9+
Base.isempty(::DropBuffer) = true
10+
isfull(::DropBuffer) = false
11+
Base.put!(::DropBuffer, _) = nothing
12+
Base.take!(::DropBuffer) = error("Cannot `take!` from a DropBuffer")
13+
14+
"A process-local buffer backed by a `Channel{T}`."
15+
struct ChannelBuffer{T}
16+
channel::Channel{T}
17+
len::Int
18+
count::Threads.Atomic{Int}
19+
ChannelBuffer{T}(len::Int=1024) where T =
20+
new{T}(Channel{T}(len), len, Threads.Atomic{Int}(0))
21+
end
22+
Base.isempty(cb::ChannelBuffer) = isempty(cb.channel)
23+
isfull(cb::ChannelBuffer) = cb.count[] == cb.len
24+
function Base.put!(cb::ChannelBuffer{T}, x) where T
25+
put!(cb.channel, convert(T, x))
26+
Threads.atomic_add!(cb.count, 1)
27+
end
28+
function Base.take!(cb::ChannelBuffer)
29+
take!(cb.channel)
30+
Threads.atomic_sub!(cb.count, 1)
31+
end
32+
33+
"A cross-worker buffer backed by a `RemoteChannel{T}`."
34+
struct RemoteChannelBuffer{T}
35+
channel::RemoteChannel{Channel{T}}
36+
len::Int
37+
count::Threads.Atomic{Int}
38+
RemoteChannelBuffer{T}(len::Int=1024) where T =
39+
new{T}(RemoteChannel(()->Channel{T}(len)), len, Threads.Atomic{Int}(0))
40+
end
41+
Base.isempty(cb::RemoteChannelBuffer) = isempty(cb.channel)
42+
isfull(cb::RemoteChannelBuffer) = cb.count[] == cb.len
43+
function Base.put!(cb::RemoteChannelBuffer{T}, x) where T
44+
put!(cb.channel, convert(T, x))
45+
Threads.atomic_add!(cb.count, 1)
46+
end
47+
function Base.take!(cb::RemoteChannelBuffer)
48+
take!(cb.channel)
49+
Threads.atomic_sub!(cb.count, 1)
50+
end
51+
52+
"A process-local ring buffer."
53+
mutable struct ProcessRingBuffer{T}
54+
read_idx::Int
55+
write_idx::Int
56+
@atomic count::Int
57+
buffer::Vector{T}
58+
function ProcessRingBuffer{T}(len::Int=1024) where T
59+
buffer = Vector{T}(undef, len)
60+
return new{T}(1, 1, 0, buffer)
61+
end
62+
end
63+
Base.isempty(rb::ProcessRingBuffer) = (@atomic rb.count) == 0
64+
isfull(rb::ProcessRingBuffer) = (@atomic rb.count) == length(rb.buffer)
65+
function Base.put!(rb::ProcessRingBuffer{T}, x) where T
66+
len = length(rb.buffer)
67+
while (@atomic rb.count) == len
68+
yield()
69+
end
70+
to_write_idx = mod1(rb.write_idx, len)
71+
rb.buffer[to_write_idx] = convert(T, x)
72+
rb.write_idx += 1
73+
@atomic rb.count += 1
74+
end
75+
function Base.take!(rb::ProcessRingBuffer)
76+
while (@atomic rb.count) == 0
77+
yield()
78+
end
79+
to_read_idx = rb.read_idx
80+
rb.read_idx += 1
81+
@atomic rb.count -= 1
82+
to_read_idx = mod1(to_read_idx, length(rb.buffer))
83+
return rb.buffer[to_read_idx]
84+
end
85+
86+
#= TODO
87+
"A server-local ring buffer backed by shared-memory."
88+
mutable struct ServerRingBuffer{T}
89+
read_idx::Int
90+
write_idx::Int
91+
@atomic count::Int
92+
buffer::Vector{T}
93+
function ServerRingBuffer{T}(len::Int=1024) where T
94+
buffer = Vector{T}(undef, len)
95+
return new{T}(1, 1, 0, buffer)
96+
end
97+
end
98+
Base.isempty(rb::ServerRingBuffer) = (@atomic rb.count) == 0
99+
function Base.put!(rb::ServerRingBuffer{T}, x) where T
100+
len = length(rb.buffer)
101+
while (@atomic rb.count) == len
102+
yield()
103+
end
104+
to_write_idx = mod1(rb.write_idx, len)
105+
rb.buffer[to_write_idx] = convert(T, x)
106+
rb.write_idx += 1
107+
@atomic rb.count += 1
108+
end
109+
function Base.take!(rb::ServerRingBuffer)
110+
while (@atomic rb.count) == 0
111+
yield()
112+
end
113+
to_read_idx = rb.read_idx
114+
rb.read_idx += 1
115+
@atomic rb.count -= 1
116+
to_read_idx = mod1(to_read_idx, length(rb.buffer))
117+
return rb.buffer[to_read_idx]
118+
end
119+
=#
120+
121+
#=
122+
"A TCP-based ring buffer."
123+
mutable struct TCPRingBuffer{T}
124+
read_idx::Int
125+
write_idx::Int
126+
@atomic count::Int
127+
buffer::Vector{T}
128+
function TCPRingBuffer{T}(len::Int=1024) where T
129+
buffer = Vector{T}(undef, len)
130+
return new{T}(1, 1, 0, buffer)
131+
end
132+
end
133+
Base.isempty(rb::TCPRingBuffer) = (@atomic rb.count) == 0
134+
function Base.put!(rb::TCPRingBuffer{T}, x) where T
135+
len = length(rb.buffer)
136+
while (@atomic rb.count) == len
137+
yield()
138+
end
139+
to_write_idx = mod1(rb.write_idx, len)
140+
rb.buffer[to_write_idx] = convert(T, x)
141+
rb.write_idx += 1
142+
@atomic rb.count += 1
143+
end
144+
function Base.take!(rb::TCPRingBuffer)
145+
while (@atomic rb.count) == 0
146+
yield()
147+
end
148+
to_read_idx = rb.read_idx
149+
rb.read_idx += 1
150+
@atomic rb.count -= 1
151+
to_read_idx = mod1(to_read_idx, length(rb.buffer))
152+
return rb.buffer[to_read_idx]
153+
end
154+
=#
155+
156+
#=
157+
"""
158+
A flexible puller which switches to the most efficient buffer type based
159+
on the sender and receiver locations.
160+
"""
161+
mutable struct UniBuffer{T}
162+
buffer::Union{ProcessRingBuffer{T}, Nothing}
163+
end
164+
function initialize_stream_buffer!(::Type{UniBuffer{T}}, T, send_proc, recv_proc, buffer_amount) where T
165+
if buffer_amount == 0
166+
error("Return NullBuffer")
167+
end
168+
send_osproc = get_parent(send_proc)
169+
recv_osproc = get_parent(recv_proc)
170+
if send_osproc.pid == recv_osproc.pid
171+
inner = RingBuffer{T}(buffer_amount)
172+
elseif system_uuid(send_osproc.pid) == system_uuid(recv_osproc.pid)
173+
inner = ProcessBuffer{T}(buffer_amount)
174+
else
175+
inner = RemoteBuffer{T}(buffer_amount)
176+
end
177+
return UniBuffer{T}(buffer_amount)
178+
end
179+
180+
struct LocalPuller{T,B}
181+
buffer::B{T}
182+
id::UInt
183+
function LocalPuller{T,B}(id::UInt, buffer_amount::Integer) where {T,B}
184+
buffer = initialize_stream_buffer!(B, T, buffer_amount)
185+
return new{T,B}(buffer, id)
186+
end
187+
end
188+
function Base.take!(pull::LocalPuller{T,B}) where {T,B}
189+
if pull.buffer === nothing
190+
pull.buffer =
191+
error("Return NullBuffer")
192+
end
193+
value = take!(pull.buffer)
194+
end
195+
function initialize_input_stream!(stream::Stream{T,B}, id::UInt, send_proc::Processor, recv_proc::Processor, buffer_amount::Integer) where {T,B}
196+
local_buffer = remotecall_fetch(stream.ref.handle.owner, stream.ref.handle, id) do ref, id
197+
local_buffer, remote_buffer = initialize_stream_buffer!(B, T, send_proc, recv_proc, buffer_amount)
198+
ref.buffers[id] = remote_buffer
199+
return local_buffer
200+
end
201+
stream.buffer = local_buffer
202+
return stream
203+
end
204+
=#

0 commit comments

Comments
 (0)