Skip to content

Commit 94ae9b2

Browse files
authored
Use Distributed.pmap in ParallelTableTransform and ColwiseFeatureTransform (#288)
* Use 'Distributed.pmap' in 'ParallelTableTransform'
* Use 'Distributed.pmap' in 'ColwiseFeatureTransform'
* Apply suggestions
* Update docs
1 parent 56e0add commit 94ae9b2

File tree

5 files changed

+17
-30
lines changed

5 files changed

+17
-30
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"
99
CoDa = "5900dafe-f573-5c72-b367-76665857777b"
1010
ColumnSelectors = "9cc86067-7e36-4c61-b350-1ac9833d277f"
1111
DataScienceTraits = "6cb2f572-2d2b-4ba6-bdb3-e710fa044d6c"
12+
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
1213
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
1314
InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112"
1415
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
@@ -18,7 +19,6 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1819
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
1920
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
2021
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
21-
Transducers = "28d57a85-8fef-5791-bfe6-a80928e7c999"
2222
TransformsBase = "28dd2a49-a57a-4bfb-84ca-1a49db9b96b8"
2323
Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"
2424

@@ -28,6 +28,7 @@ CategoricalArrays = "0.10"
2828
CoDa = "1.2"
2929
ColumnSelectors = "0.1"
3030
DataScienceTraits = "0.3"
31+
Distributed = "1.9"
3132
Distributions = "0.25"
3233
InverseFunctions = "0.1"
3334
LinearAlgebra = "1.9"
@@ -37,7 +38,6 @@ Random = "1.9"
3738
Statistics = "1.9"
3839
StatsBase = "0.33, 0.34"
3940
Tables = "1.6"
40-
Transducers = "0.4"
4141
TransformsBase = "1.5"
4242
Unitful = "1.17"
4343
julia = "1.9"

docs/src/index.md

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,7 @@ missing transform, and contributions are very welcome.
2525
revertible when the individual transforms are revertible.
2626

2727
- Branches of a pipeline and colwise transforms are run in parallel
28-
using multiple threads with the awesome
29-
[Transducers.jl](https://github.com/JuliaFolds/Transducers.jl)
30-
framework.
28+
using multiple processes with the Distributed standard library.
3129

3230
- Pipelines can be reapplied to unseen "test" data using the same cache
3331
(e.g. constants) fitted with "training" data. For example, a `ZScore`

src/TableTransforms.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ using Unitful: AbstractQuantity, AffineQuantity, AffineUnits, Units
2323
using Distributions: ContinuousUnivariateDistribution, Normal
2424
using InverseFunctions: NoInverse, inverse as invfun
2525
using StatsBase: AbstractWeights, Weights, sample
26-
using Transducers: tcollect
26+
using Distributed: CachingPool, pmap, workers
2727
using NelderMead: optimise
2828

2929
import Distributions: quantile, cdf

src/transforms.jl

Lines changed: 9 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -184,8 +184,9 @@ function applyfeat(transform::ColwiseFeatureTransform, feat, prep)
184184
names = Tables.columnnames(cols)
185185
snames = transform.selector(names)
186186

187-
# function to transform a single column
188-
function colfunc(n)
187+
# transform each column in parallel
188+
pool = CachingPool(workers())
189+
vals = pmap(pool, names) do n
189190
x = Tables.getcolumn(cols, n)
190191
if n ∈ snames
191192
c = colcache(transform, x)
@@ -197,9 +198,6 @@ function applyfeat(transform::ColwiseFeatureTransform, feat, prep)
197198
(n => y), c
198199
end
199200

200-
# parallel map with multiple threads
201-
vals = tcollect(colfunc(n) for n in names)
202-
203201
# new table with transformed columns
204202
𝒯 = (; first.(vals)...)
205203
newfeat = 𝒯 |> Tables.materializer(feat)
@@ -218,18 +216,14 @@ function revertfeat(transform::ColwiseFeatureTransform, newfeat, fcache)
218216

219217
caches, snames = fcache
220218

221-
# function to transform a single column
222-
function colfunc(i)
223-
n = names[i]
224-
c = caches[i]
219+
# revert each column in parallel
220+
pool = CachingPool(workers())
221+
vals = pmap(pool, names, caches) do n, c
225222
y = Tables.getcolumn(cols, n)
226223
x = n ∈ snames ? colrevert(transform, y, c) : y
227224
n => x
228225
end
229226

230-
# parallel map with multiple threads
231-
vals = tcollect(colfunc(i) for i in 1:length(names))
232-
233227
# new table with transformed columns
234228
(; vals...) |> Tables.materializer(newfeat)
235229
end
@@ -244,18 +238,14 @@ function reapplyfeat(transform::ColwiseFeatureTransform, feat, fcache)
244238
# check that cache is valid
245239
_assert(length(names) == length(caches), "invalid caches for feat")
246240

247-
# function to transform a single column
248-
function colfunc(i)
249-
n = names[i]
250-
c = caches[i]
241+
# transform each column in parallel
242+
pool = CachingPool(workers())
243+
vals = pmap(pool, names, caches) do n, c
251244
x = Tables.getcolumn(cols, n)
252245
y = n ∈ snames ? colapply(transform, x, c) : x
253246
n => y
254247
end
255248

256-
# parallel map with multiple threads
257-
vals = tcollect(colfunc(i) for i in 1:length(names))
258-
259249
# new table with transformed columns
260250
(; vals...) |> Tables.materializer(feat)
261251
end

src/transforms/parallel.jl

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,8 @@ isrevertible(p::ParallelTableTransform) = any(isrevertible, p.transforms)
4141

4242
function apply(p::ParallelTableTransform, table)
4343
# apply transforms in parallel
44-
f(transform) = apply(transform, table)
45-
vals = tcollect(f(t) for t in p.transforms)
44+
pool = CachingPool(workers())
45+
vals = pmap(t -> apply(t, table), pool, p.transforms)
4646

4747
# retrieve tables and caches
4848
tables = first.(vals)
@@ -122,9 +122,8 @@ function reapply(p::ParallelTableTransform, table, cache)
122122
caches = cache[1]
123123

124124
# reapply transforms in parallel
125-
f(t, c) = reapply(t, table, c)
126-
itr = zip(p.transforms, caches)
127-
tables = tcollect(f(t, c) for (t, c) in itr)
125+
pool = CachingPool(workers())
126+
tables = pmap((t, c) -> reapply(t, table, c), pool, p.transforms, caches)
128127

129128
# features and metadata
130129
splits = divide.(tables)

0 commit comments

Comments
 (0)