Skip to content

Commit 7576323

Browse files
Merge pull request #205 from tcarion/diskarrays
DiskArrays for `Variable`'s
2 parents 36d2d51 + fc4e099 commit 7576323

13 files changed

+89
-155
lines changed

Project.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ CFTime = "179af706-886a-5703-950a-314cd64e0468"
1010
CommonDataModel = "1fbeeb36-5f17-413c-809b-666fb144f157"
1111
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
1212
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
13+
DiskArrays = "3c3547ce-8d99-4f5e-a174-61eb10b00ae3"
1314
NetCDF_jll = "7243133f-43d8-5620-bbf4-c2c921802cf3"
1415
NetworkOptions = "ca575930-c2e3-43a9-ace4-1e988b2c1908"
1516
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
@@ -24,10 +25,10 @@ julia = "1.3"
2425

2526
[extras]
2627
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
28+
IntervalSets = "8197267c-284f-5f27-9208-e0e47529a953"
2729
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
2830
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
2931
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
30-
IntervalSets = "8197267c-284f-5f27-9208-e0e47529a953"
3132

3233
[targets]
3334
test = ["Dates", "Test", "Random", "Printf", "IntervalSets"]

src/NCDatasets.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,9 @@ import CommonDataModel: AbstractDataset, AbstractVariable,
3838
groupnames, group, defGroup,
3939
dimnames, dim, defDim,
4040
attribnames, attrib, defAttrib
41+
import DiskArrays
42+
import DiskArrays: readblock!, writeblock!, eachchunk, haschunks
43+
using DiskArrays: @implement_diskarray
4144

4245
function __init__()
4346
NetCDF_jll.is_available() && init_certificate_authority()
@@ -65,6 +68,8 @@ include("ncgen.jl")
6568
include("select.jl")
6669
include("precompile.jl")
6770

71+
@implement_diskarray NCDatasets.Variable
72+
6873
export CatArrays
6974
export CFTime
7075
export daysinmonth, daysinyear, yearmonthday, yearmonth, monthday

src/cfvariable.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -311,7 +311,7 @@ function _range_indices_dest(of,v,rest...)
311311
end
312312
range_indices_dest(ri...) = _range_indices_dest((),ri...)
313313

314-
function Base.getindex(v::Union{CFVariable,Variable,MFVariable,SubVariable},indices::Union{Int,Colon,AbstractRange{<:Integer},Vector{Int}}...)
314+
function Base.getindex(v::Union{MFVariable,SubVariable},indices::Union{Int,Colon,AbstractRange{<:Integer},Vector{Int}}...)
315315
@debug "transform vector of indices to ranges"
316316

317317
sz_source = size(v)

src/dataset.jl

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -489,20 +489,32 @@ function Base.write(dest::NCDataset, src::AbstractDataset;
489489
# end
490490
end
491491

492+
function _destindex(ind, dimname, dimlength, unlimdims)
493+
nind = _normalizeindex(dimlength, ind)
494+
if dimname in unlimdims
495+
nind[1]:dimlength
496+
else
497+
nind
498+
end
499+
end
500+
_maxrange(dimname, idimensions, dimlength) = haskey(idimensions, dimname) ? idimensions[dimname][end] : dimlength
501+
492502
# loop over variables
493503
for varname in include
494504
(varname exclude) && continue
495505
@debug "Writing variable $varname..."
496506

497507
cfvar = src[varname]
508+
cfsz = size(cfvar)
498509
dimension_names = dimnames(cfvar)
499510
var = cfvar.var
500511
# indices for subset
501512
index = ntuple(i -> torange(get(idimensions,dimension_names[i],:)),length(dimension_names))
513+
destindex = ntuple(i -> _destindex(index[i], dimension_names[i], _maxrange(dimension_names[i], idimensions, cfsz[i]), unlimited_dims), length(dimension_names))
502514

503515
destvar = defVar(dest, varname, eltype(var), dimension_names; attrib = attribs(cfvar))
504516
# copy data
505-
destvar.var[:] = cfvar.var[index...]
517+
destvar.var[destindex...] = cfvar.var[index...]
506518
end
507519

508520
# loop over all global attributes

src/subvariable.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ close(ds)
9191
```
9292
9393
"""
94-
Base.view(v::AbstractVariable,indices::Union{Int,Colon,AbstractVector{Int}}...) = SubVariable(v,indices...)
94+
Base.view(v::Union{CFVariable, DeferVariable, MFCFVariable},indices::Union{Int,Colon,AbstractVector{Int}}...) = SubVariable(v,indices...)
9595
Base.view(v::SubVariable,indices::CartesianIndex) = view(v,indices.I...)
9696
Base.view(v::SubVariable,indices::CartesianIndices) = view(v,indices.indices...)
9797

src/variable.jl

Lines changed: 46 additions & 130 deletions
Original file line numberDiff line numberDiff line change
@@ -345,85 +345,68 @@ end
345345
nomissing(a::AbstractArray,value) = a
346346
export nomissing
347347

348-
349-
function Base.getindex(v::Variable,indexes::Int...)
348+
# This method needs to be duplicated instead of using an Union. Otherwise a DiskArrays fallback is called instead which impacts performances
349+
# (see https://github.com/Alexander-Barth/NCDatasets.jl/pull/205#issuecomment-1589575041)
350+
function readblock!(v::Variable, aout, indexes::TI...) where TI <: Union{AbstractUnitRange,StepRange}
350351
datamode(v.ds)
351-
return nc_get_var1(eltype(v),v.ds.ncid,v.varid,[i-1 for i in indexes[ndims(v):-1:1]])
352+
_readblock!(v, aout, indexes...)
353+
return aout
352354
end
353355

354-
function Base.setindex!(v::Variable{T,N},data,indexes::Int...) where N where T
355-
@debug "$(@__LINE__)"
356-
datamode(v.ds)
357-
# use zero-based indexes and reversed order
358-
nc_put_var1(v.ds.ncid,v.varid,[i-1 for i in indexes[ndims(v):-1:1]],T(data))
359-
return data
360-
end
356+
_readblock!(v::Variable, aout, indexes::AbstractUnitRange...) = _read_data_from_nc!(v, aout, indexes...)
357+
_readblock!(v::Variable, aout, indexes::StepRange...) = _read_data_from_nc!(v, aout, indexes...)
361358

362-
function Base.getindex(v::Variable{T,N},indexes::Colon...) where {T,N}
363-
datamode(v.ds)
364-
data = Array{T,N}(undef,size(v))
365-
nc_get_var!(v.ds.ncid,v.varid,data)
359+
readblock!(v::Variable, aout) = _read_data_from_nc!(v::Variable, aout)
366360

367-
# special case for scalar NetCDF variable
368-
if N == 0
369-
return data[]
370-
else
371-
return data
372-
end
361+
function _read_data_from_nc!(v::Variable, aout, indexes::Int...)
362+
aout .= nc_get_var1(eltype(v),v.ds.ncid,v.varid,[i-1 for i in reverse(indexes)])
373363
end
374364

375-
function Base.setindex!(v::Variable{T,N},data::T,indexes::Colon...) where {T,N}
376-
@debug "setindex! colon $data"
377-
datamode(v.ds) # make sure that the file is in data mode
378-
tmp = fill(data,size(v))
379-
nc_put_var(v.ds.ncid,v.varid,tmp)
380-
return data
365+
function _read_data_from_nc!(v::Variable{T,N}, aout, indexes::TR...) where {T,N} where TR <: Union{StepRange{Int,Int},UnitRange{Int}}
366+
start,count,stride,jlshape = ncsub(indexes)
367+
nc_get_vars!(v.ds.ncid,v.varid,start,count,stride,aout)
381368
end
382369

383-
# union types cannot be used to avoid ambiguity
384-
for data_type = [Number, String, Char]
385-
@eval begin
386-
# call to v .= 123
387-
function Base.setindex!(v::Variable{T,N},data::$data_type) where {T,N}
388-
@debug "setindex! $data"
389-
datamode(v.ds) # make sure that the file is in data mode
390-
tmp = fill(convert(T,data),size(v))
391-
nc_put_var(v.ds.ncid,v.varid,tmp)
392-
return data
393-
end
394-
395-
Base.setindex!(v::Variable,data::$data_type,indexes::Colon...) = setindex!(v::Variable,data)
396-
397-
function Base.setindex!(v::Variable{T,N},data::$data_type,indexes::StepRange{Int,Int}...) where {T,N}
398-
datamode(v.ds) # make sure that the file is in data mode
399-
start,count,stride,jlshape = ncsub(indexes[1:ndims(v)])
400-
tmp = fill(convert(T,data),jlshape)
401-
nc_put_vars(v.ds.ncid,v.varid,start,count,stride,tmp)
402-
return data
403-
end
404-
end
370+
function _read_data_from_nc!(v::Variable{T,N}, aout, indexes::Union{Int,Colon,AbstractRange{<:Integer}}...) where {T,N}
371+
sz = size(v)
372+
start,count,stride = ncsub2(sz,indexes...)
373+
jlshape = _shape_after_slice(sz,indexes...)
374+
nc_get_vars!(v.ds.ncid,v.varid,start,count,stride,aout)
405375
end
406376

407-
function Base.setindex!(v::Variable{T,N},data::AbstractArray{T,N},indexes::Colon...) where {T,N}
408-
datamode(v.ds) # make sure that the file is in data mode
377+
_read_data_from_nc!(v::Variable, aout) = _read_data_from_nc!(v, aout, 1)
409378

410-
nc_put_var(v.ds.ncid,v.varid,data)
379+
function writeblock!(v::Variable, data, indexes::TI...) where TI <: Union{AbstractUnitRange,StepRange}
380+
datamode(v.ds)
381+
_write_data_to_nc(v, data, indexes...)
411382
return data
412383
end
413384

414-
function Base.setindex!(v::Variable{T,N},data::AbstractArray{T2,N},indexes::Colon...) where {T,T2,N}
415-
datamode(v.ds) # make sure that the file is in data mode
416-
tmp =
417-
if T <: Integer
418-
round.(T,data)
419-
else
420-
convert(Array{T,N},data)
421-
end
385+
function _write_data_to_nc(v::Variable{T,N},data,indexes::Int...) where {T,N}
386+
nc_put_var1(v.ds.ncid,v.varid,[i-1 for i in reverse(indexes)],T(data[1]))
387+
end
422388

423-
nc_put_var(v.ds.ncid,v.varid,tmp)
424-
return data
389+
_write_data_to_nc(v::Variable, data) = _write_data_to_nc(v, data, 1)
390+
391+
function _write_data_to_nc(v::Variable{T, N}, data, indexes::StepRange{Int,Int}...) where {T, N}
392+
start,count,stride,jlshape = ncsub(indexes)
393+
nc_put_vars(v.ds.ncid,v.varid,start,count,stride,T.(data))
394+
end
395+
396+
function _write_data_to_nc(v::Variable, data, indexes::Union{AbstractRange{<:Integer}}...)
397+
ind = prod(length.(indexes)) == 1 ? first.(indexes) : normalizeindexes(size(v),indexes)
398+
return _write_data_to_nc(v, data, ind...)
425399
end
426400

401+
getchunksize(v::Variable) = getchunksize(haschunks(v),v)
402+
getchunksize(::DiskArrays.Chunked, v::Variable) = chunking(v)[2]
403+
# getchunksize(::DiskArrays.Unchunked, v::Variable) = DiskArrays.estimate_chunksize(v)
404+
getchunksize(::DiskArrays.Unchunked, v::Variable) = size(v)
405+
eachchunk(v::CFVariable) = eachchunk(v.var)
406+
haschunks(v::CFVariable) = haschunks(v.var)
407+
eachchunk(v::Variable) = DiskArrays.GridChunks(v, Tuple(getchunksize(v)))
408+
haschunks(v::Variable) = (chunking(v)[1] == :contiguous ? DiskArrays.Unchunked() : DiskArrays.Chunked())
409+
427410
_normalizeindex(n,ind::Base.OneTo) = 1:1:ind.stop
428411
_normalizeindex(n,ind::Colon) = 1:1:n
429412
_normalizeindex(n,ind::Int) = ind:1:ind
@@ -477,72 +460,5 @@ end
477460
return start,count,stride
478461
end
479462

480-
function Base.getindex(v::Variable{T,N},indexes::TR...) where {T,N} where TR <: Union{StepRange{Int,Int},UnitRange{Int}}
481-
start,count,stride,jlshape = ncsub(indexes[1:N])
482-
data = Array{T,N}(undef,jlshape)
483-
484-
datamode(v.ds)
485-
nc_get_vars!(v.ds.ncid,v.varid,start,count,stride,data)
486-
return data
487-
end
488-
489-
function Base.setindex!(v::Variable{T,N},data::T,indexes::StepRange{Int,Int}...) where {T,N}
490-
datamode(v.ds) # make sure that the file is in data mode
491-
start,count,stride,jlshape = ncsub(indexes[1:ndims(v)])
492-
tmp = fill(data,jlshape)
493-
nc_put_vars(v.ds.ncid,v.varid,start,count,stride,tmp)
494-
return data
495-
end
496-
497-
function Base.setindex!(v::Variable{T,N},data::Array{T,N},indexes::StepRange{Int,Int}...) where {T,N}
498-
datamode(v.ds) # make sure that the file is in data mode
499-
start,count,stride,jlshape = ncsub(indexes[1:ndims(v)])
500-
nc_put_vars(v.ds.ncid,v.varid,start,count,stride,data)
501-
return data
502-
end
503-
504-
# data can be Array{T2,N} or BitArray{N}
505-
function Base.setindex!(v::Variable{T,N},data::AbstractArray,indexes::StepRange{Int,Int}...) where {T,N}
506-
datamode(v.ds) # make sure that the file is in data mode
507-
start,count,stride,jlshape = ncsub(indexes[1:ndims(v)])
508-
509-
tmp = convert(Array{T,ndims(data)},data)
510-
nc_put_vars(v.ds.ncid,v.varid,start,count,stride,tmp)
511-
512-
return data
513-
end
514-
515-
516-
517-
518-
function Base.getindex(v::Variable{T,N},indexes::Union{Int,Colon,AbstractRange{<:Integer}}...) where {T,N}
519-
sz = size(v)
520-
start,count,stride = ncsub2(sz,indexes...)
521-
jlshape = _shape_after_slice(sz,indexes...)
522-
data = Array{T}(undef,jlshape)
523-
524-
datamode(v.ds)
525-
nc_get_vars!(v.ds.ncid,v.varid,start,count,stride,data)
526-
527-
return data
528-
end
529-
530-
# NetCDF scalars indexed as []
531-
Base.getindex(v::Variable{T, 0}) where T = v[1]
532-
533-
534-
535-
function Base.setindex!(v::Variable,data,indexes::Union{Int,Colon,AbstractRange{<:Integer}}...)
536-
ind = normalizeindexes(size(v),indexes)
537-
538-
# make arrays out of scalars (arrays can have zero dimensions)
539-
if (ndims(data) == 0) && !(data isa AbstractArray)
540-
data = fill(data,length.(ind))
541-
end
542-
543-
return v[ind...] = data
544-
end
545-
546-
547-
Base.getindex(v::Union{MFVariable,DeferVariable,Variable},ci::CartesianIndices) = v[ci.indices...]
548-
Base.setindex!(v::Union{MFVariable,DeferVariable,Variable},data,ci::CartesianIndices) = setindex!(v,data,ci.indices...)
463+
Base.getindex(v::Union{MFVariable,DeferVariable},ci::CartesianIndices) = v[ci.indices...]
464+
Base.setindex!(v::Union{MFVariable,DeferVariable},data,ci::CartesianIndices) = setindex!(v,data,ci.indices...)

test/perf/generate_data.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ ncv1 = defVar(ds,"v1", UInt8, ("longitude", "latitude", "time"), fillvalue = UIn
2929
for n = 1:sz[3]
3030
@show n
3131
ncv1[:,:,n] = rand(1:100,sz[1],sz[2])
32-
ncv1[:,1,n] = missing
32+
ncv1[:,1,n] .= missing
3333
end
3434

3535
close(ds)

test/test_check_size.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,10 @@ defVar(ds, "w", Float64, ("x", "Time"))
1717

1818
for i in 1:10
1919
ds["Time"][i] = i
20-
ds["a"][:,i] = 1
21-
@test_throws NCDatasets.NetCDFError ds["u"][:,i] = collect(1:9)
22-
@test_throws NCDatasets.NetCDFError ds["v"][:,i] = collect(1:11)
23-
@test_throws NCDatasets.NetCDFError ds["w"][:,i] = reshape(collect(1:20), 10, 2)
20+
ds["a"][:,i] .= 1
21+
@test_throws DimensionMismatch ds["u"][:,i] = collect(1:9)
22+
@test_throws DimensionMismatch ds["v"][:,i] = collect(1:11)
23+
@test_throws DimensionMismatch ds["w"][:,i] = reshape(collect(1:20), 10, 2)
2424

2525
# ignore singleton dimension
2626
ds["w"][:,i] = reshape(collect(1:10), 1, 1, 10, 1)
@@ -29,11 +29,11 @@ end
2929
ds["w"][:,:] = ones(10,10)
3030

3131
# w should grow along the unlimited dimension
32-
ds["w"][:,:] = ones(10,15)
32+
ds["w"][:,1:15] = ones(10,15)
3333
@test size(ds["w"]) == (10,15)
3434

3535
# w cannot grow along a fixed dimension
36-
@test_throws NCDatasets.NetCDFError ds["w"][:,:] = ones(11,15)
36+
@test_throws DimensionMismatch ds["w"][:,:] = ones(11,15)
3737

3838
# NetCDF: Index exceeds dimension bound
3939
@test_throws NCDatasets.NetCDFError ds["u"][100,100]

test/test_corner_cases.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ b = dropdims([1.], dims=(1,))
4646
NCDataset(fname,"c") do ds
4747
time = defDim(ds,"time",Inf)
4848
v = defVar(ds,"temp",Float32,("time",))
49-
ds["temp"][1:1] = b
49+
ds["temp"][1] = b
5050
@test ds["temp"][1] == 1
5151
end
5252

test/test_scalar.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ for (T,data) in ((Float32,123.f0),
3232
end
3333

3434
NCDataset(filename,"r") do ds
35-
v2 = ds["scalar"][:]
35+
v2 = ds["scalar"][1]
3636
@test v2 == data
3737
end
3838
rm(filename)

0 commit comments

Comments
 (0)