Skip to content

Commit d760a1f

Browse files
authored
Add DefaultArrayInterface methods (#11)
1 parent bb982a0 commit d760a1f

File tree

7 files changed

+106
-44
lines changed

7 files changed

+106
-44
lines changed

.pre-commit-config.yaml

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@ repos:
77
- id: check-yaml
88
- id: end-of-file-fixer
99
exclude_types: [markdown] # incompatible with Literate.jl
10-
- repo: https://github.com/qiaojunfeng/pre-commit-julia-format
11-
rev: v0.2.0
10+
11+
- repo: "https://github.com/domluna/JuliaFormatter.jl"
12+
rev: v1.0.62
1213
hooks:
13-
- id: julia-format
14+
- id: "julia-formatter"

Project.toml

Lines changed: 1 addition & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "DerivableInterfaces"
22
uuid = "6c5e35bf-e59e-4898-b73c-732dcc4ba65f"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.3.7"
4+
version = "0.3.8"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
@@ -14,22 +14,9 @@ TypeParameterAccessors = "7e5a90cf-f82e-492e-a09b-e3e26432c138"
1414

1515
[compat]
1616
Adapt = "4.1.1"
17-
Aqua = "0.8.9"
1817
ArrayLayouts = "1.11.0"
1918
ExproniconLite = "0.10.13"
2019
LinearAlgebra = "1.10"
2120
MLStyle = "0.4.17"
2221
MapBroadcast = "0.1.5"
23-
SafeTestsets = "0.1"
24-
Suppressor = "0.2"
25-
Test = "1.10"
2622
julia = "1.10"
27-
28-
[extras]
29-
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
30-
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
31-
Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb"
32-
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
33-
34-
[targets]
35-
test = ["Aqua", "Test", "Suppressor", "SafeTestsets"]

src/defaultarrayinterface.jl

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,3 +6,27 @@ function interface(a::Type{<:AbstractArray})
66
parenttype(a) === a && return DefaultArrayInterface()
77
return interface(parenttype(a))
88
end
9+
10+
@interface ::DefaultArrayInterface function Base.getindex(
11+
a::AbstractArray{<:Any,N}, I::Vararg{Int,N}
12+
) where {N}
13+
return Base.getindex(a, I...)
14+
end
15+
16+
@interface ::DefaultArrayInterface function Base.setindex!(
17+
a::AbstractArray{<:Any,N}, value, I::Vararg{Int,N}
18+
) where {N}
19+
return Base.setindex!(a, value, I...)
20+
end
21+
22+
@interface ::DefaultArrayInterface function Base.map!(
23+
f, a_dest::AbstractArray, a_srcs::AbstractArray...
24+
)
25+
return Base.map!(f, a_dest, a_srcs...)
26+
end
27+
28+
@interface ::DefaultArrayInterface function Base.mapreduce(
29+
f, op, as::AbstractArray...; kwargs...
30+
)
31+
return Base.mapreduce(f, op, as...; kwargs...)
32+
end

src/interface_macro.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -80,10 +80,11 @@ DerivableInterfaces.call(SparseArrayInterface(), Base.setindex!, a, value, I...)
8080
```
8181
=#
8282
function interface_setref(interface::Union{Symbol,Expr}, func::Expr)
83-
func = @match func begin
84-
:($a[$(I...)] = $value) => :(Base.setindex!($a, $value, $(I...)))
83+
return @match func begin
84+
:($a[$(I...)] = $value) => Expr(
85+
:block, interface_call(interface, :(Base.setindex!($a, $value, $(I...)))), :($value)
86+
)
8587
end
86-
return interface_call(interface, func)
8788
end
8889

8990
#=

test/Project.toml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,3 +7,9 @@ MapBroadcast = "ebd9b9da-f48d-417c-9660-449667d60261"
77
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
88
Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb"
99
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
10+
11+
[compat]
12+
Aqua = "0.8.9"
13+
SafeTestsets = "0.1"
14+
Suppressor = "0.2"
15+
Test = "1.10"
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
using Test: @testset, @test, @inferred
2+
using DerivableInterfaces: @interface, DefaultArrayInterface
3+
4+
# function wrappers to test type-stability
5+
_getindex(A, i...) = @interface DefaultArrayInterface() A[i...]
6+
_setindex!(A, v, i...) = @interface DefaultArrayInterface() A[i...] = v
7+
_map!(args...) = @interface DefaultArrayInterface() map!(args...)
8+
function _mapreduce(args...; kwargs...)
9+
@interface DefaultArrayInterface() mapreduce(args...; kwargs...)
10+
end
11+
12+
@testset "indexing" begin
13+
for (A, i) in ((zeros(2), 2), (zeros(2, 2), (2, 1)), (zeros(1, 2, 3), (1, 2, 3)))
14+
a = @inferred _getindex(A, i...)
15+
@test a == A[i...]
16+
v = 1.1
17+
A′ = @inferred _setindex!(A, v, i...)
18+
@test A′ == (A[i...] = v)
19+
end
20+
end
21+
22+
@testset "map!" begin
23+
A = zeros(3)
24+
a = @inferred _map!(Returns(2), copy(A), A)
25+
@test a == map!(Returns(2), copy(A), A)
26+
end
27+
28+
@testset "mapreduce" begin
29+
A = zeros(3)
30+
a = @inferred _mapreduce(Returns(2), +, A)
31+
@test a == mapreduce(Returns(2), +, A)
32+
end

test/runtests.jl

Lines changed: 35 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,31 @@
11
using SafeTestsets: @safetestset
2-
using Suppressor: @suppress
3-
using Test: @testset
2+
using Suppressor: Suppressor
43

5-
@testset "DerivableInterfaces.jl tests" begin
6-
# check for filtered groups
7-
# either via `--group=ALL` or through ENV["GROUP"]
8-
pat = r"(?:--group=)(\w+)"
9-
arg_id = findfirst(contains(pat), ARGS)
10-
GROUP = uppercase(
11-
if isnothing(arg_id)
12-
get(ENV, "GROUP", "ALL")
13-
else
14-
only(match(pat, ARGS[arg_id]).captures)
15-
end,
16-
)
4+
# check for filtered groups
5+
# either via `--group=ALL` or through ENV["GROUP"]
6+
const pat = r"(?:--group=)(\w+)"
7+
arg_id = findfirst(contains(pat), ARGS)
8+
const GROUP = uppercase(
9+
if isnothing(arg_id)
10+
get(ENV, "GROUP", "ALL")
11+
else
12+
only(match(pat, ARGS[arg_id]).captures)
13+
end,
14+
)
1715

18-
function istestfile(filename)
19-
return isfile(filename) &&
20-
endswith(filename, ".jl") &&
21-
startswith(basename(filename), "test")
22-
end
16+
"match files of the form `test_*.jl`, but exclude `*setup*.jl`"
17+
istestfile(fn) =
18+
endswith(fn, ".jl") && startswith(basename(fn), "test_") && !contains(fn, "setup")
19+
"match files of the form `*.jl`, but exclude `*_notest.jl` and `*setup*.jl`"
20+
isexamplefile(fn) =
21+
endswith(fn, ".jl") && !endswith(fn, "_notest.jl") && !contains(fn, "setup")
2322

23+
@time begin
2424
# tests in groups based on folder structure
2525
for testgroup in filter(isdir, readdir(@__DIR__))
2626
if GROUP == "ALL" || GROUP == uppercase(testgroup)
2727
for file in filter(istestfile, readdir(joinpath(@__DIR__, testgroup); join=true))
28-
@eval @safetestset $file begin
28+
@eval @safetestset $(last(splitdir(file))) begin
2929
include($file)
3030
end
3131
end
@@ -34,17 +34,28 @@ using Test: @testset
3434

3535
# single files in top folder
3636
for file in filter(istestfile, readdir(@__DIR__))
37-
(file == basename(@__FILE__)) && continue
37+
(file == basename(@__FILE__)) && continue # exclude this file to avoid infinite recursion
3838
@eval @safetestset $file begin
3939
include($file)
4040
end
4141
end
4242

4343
# test examples
4444
examplepath = joinpath(@__DIR__, "..", "examples")
45-
for file in filter(endswith(".jl"), readdir(examplepath; join=true))
46-
@suppress @eval @safetestset $file begin
47-
include($file)
45+
for (root, _, files) in walkdir(examplepath)
46+
contains(chopprefix(root, @__DIR__), "setup") && continue
47+
for file in filter(isexamplefile, files)
48+
filename = joinpath(root, file)
49+
@eval begin
50+
@safetestset $file begin
51+
$(Expr(
52+
:macrocall,
53+
GlobalRef(Suppressor, Symbol("@suppress")),
54+
LineNumberNode(@__LINE__, @__FILE__),
55+
:(include($filename)),
56+
))
57+
end
58+
end
4859
end
4960
end
5061
end

0 commit comments

Comments
 (0)