Skip to content

Commit d884369

Browse files
author
Michael Abbott
committed
rm TensorCast
1 parent da64f8d commit d884369

File tree

3 files changed

+6
-45
lines changed

3 files changed

+6
-45
lines changed

Project.toml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "SliceMap"
22
uuid = "82cb661a-3f19-5665-9e27-df437c7e54c8"
33
authors = ["Michael Abbott"]
4-
version = "0.1.2"
4+
version = "0.2"
55

66
[deps]
77
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
@@ -18,7 +18,6 @@ ForwardDiff = "0.10"
1818
JuliennedArrays = "0.2"
1919
MacroTools = "0.4, 0.5"
2020
StaticArrays = "0.10, 0.11, 0.12"
21-
TensorCast = "0.1, 0.2"
2221
Tracker = "0.2"
2322
Zygote = "0.4"
2423
ZygoteRules = "0.2"

src/SliceMap.jl

Lines changed: 4 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ module SliceMap
33

44
export mapcols, MapCols, maprows, slicemap, tmapcols, ThreadMapCols
55

6-
using MacroTools, TensorCast, JuliennedArrays
6+
using MacroTools, JuliennedArrays
77

88
using Tracker
99
using Tracker: TrackedMatrix, track, @grad, data
@@ -81,12 +81,11 @@ e.g. if `dims=(2,4)` then `f` must map matrices to matrices.
8181
The gradient is for Zygote only.
8282
"""
8383
function slicemap(f::Function, A::AbstractArray{T,N}, args...; dims) where {T,N}
84-
code = ntuple(d -> d in dims ? (:) : (*), N)
85-
B = TensorCast.sliceview(A, code)
84+
code = ntuple(d -> d in dims ? True() : False(), N)
85+
B = JuliennedArrays.Slices(A, code...)
8686
C = [ f(slice, args...) for slice in B ]
87-
TensorCast.glue(C, code)
87+
JuliennedArrays.Align(C, code...)
8888
end
89-
# TODO switch to JuliennedArrays, then rm TensorCast dep
9089

9190
#========== Forward, Static ==========#
9291

@@ -153,23 +152,6 @@ end
153152

154153
#========== Gradients for Zygote ==========#
155154

156-
# @require Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
157-
158-
# @init @require Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" include("zygote.jl")
159-
# Now using ZygoteRules instead, mapcols etc above.
160-
161-
#= TensorCast =#
162-
# These could move there, TODO
163-
164-
@adjoint TensorCast.sliceview(A::AbstractArray, code::Tuple) =
165-
TensorCast.sliceview(A, code), Δ -> (TensorCast.glue(Δ, code), nothing)
166-
167-
@adjoint TensorCast.red_glue(A::AbstractArray, code::Tuple) =
168-
TensorCast.red_glue(A, code), Δ -> (TensorCast.sliceview(Δ, code), nothing)
169-
170-
@adjoint TensorCast.copy_glue(A::AbstractArray, code::Tuple) =
171-
TensorCast.copy_glue(A, code), Δ -> (TensorCast.sliceview(Δ, code), nothing)
172-
173155
#= JuliennedArrays =#
174156

175157
@adjoint JuliennedArrays.Slices(whole, along...) =

test/runtests.jl

Lines changed: 1 addition & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11

22
using SliceMap
33
using Test
4-
using ForwardDiff, Tracker, Zygote, TensorCast, JuliennedArrays
4+
using ForwardDiff, Tracker, Zygote, JuliennedArrays
55

66
Zygote.refresh()
77

@@ -35,10 +35,6 @@ Zygote.refresh()
3535
@test grad Zygote.gradient(m -> sum(sin, tmapcols(fun, m)), mat)[1]
3636
@test grad Zygote.gradient(m -> sum(sin, ThreadMapCols{3}(fun, m)), mat)[1]
3737

38-
tcm(mat) = @cast out[i,j] := fun(mat[:,j])[i]
39-
@test res tcm(mat)
40-
@test grad Zygote.gradient(m -> sum(sin, tcm(m)), mat)[1]
41-
4238
jcols(f,m) = Align(map(f, Slices(m, True(), False())), True(), False())
4339
@test res jcols(fun, mat)
4440
@test grad Zygote.gradient(m -> sum(sin, jcols(fun, m)), mat)[1]
@@ -61,10 +57,6 @@ end
6157
@test grad Zygote.gradient(m -> sum(sin, mapcols(fun, m)), mat)[1]
6258
@test grad Zygote.gradient(m -> sum(sin, MapCols{3}(fun, m)), mat)[1]
6359

64-
# tcm3(mat) = @cast out[_,j] := fun(mat[:,j]) # changed here too
65-
# @test res ≈ tcm3(mat)
66-
# @test grad ≈ Zygote.gradient(m -> sum(sin, tcm3(m)), mat)[1]
67-
6860
end
6961
@testset "columns -> matrix" begin
7062

@@ -83,10 +75,6 @@ end
8375
@test grad Zygote.gradient(m -> sum(sin, mapcols(fun, m)), mat)[1]
8476
@test grad Zygote.gradient(m -> sum(sin, MapCols{3}(fun, m)), mat)[1]
8577

86-
# tcm4(mat) = @cast out[i⊗i′,j] := fun(mat[:,j])[i,i′] i:3
87-
# @test res ≈ tcm4(mat)
88-
# @test grad ≈ Zygote.gradient(m -> sum(sin, tcm4(m)), mat)[1]
89-
9078
end
9179
@testset "columns w args" begin
9280

@@ -105,10 +93,6 @@ end
10593
@test grad Zygote.gradient(m -> sum(sin, mapcols(fun, m, 5)), mat)[1]
10694
@test grad Zygote.gradient(m -> sum(sin, MapCols{3}(fun, m, 5)), mat)[1]
10795

108-
# tcm5(mat) = @cast out[i,j] := fun(mat[:,j], 5)[i]
109-
# @test res ≈ tcm5(mat)
110-
# @test grad ≈ Zygote.gradient(m -> sum(sin, tcm5(m)), mat)[1]
111-
11296
end
11397
@testset "rows" begin
11498

@@ -122,10 +106,6 @@ end
122106
@test grad Tracker.gradient(m -> sum(sin, maprows(fun, m)), mat)[1]
123107
@test grad Zygote.gradient(m -> sum(sin, maprows(fun, m)), mat)[1]
124108

125-
# tcm2(mat) = @cast out[i,j] := fun(mat[i,:])[j]
126-
# @test res ≈ tcm2(mat)
127-
# @test grad ≈ Zygote.gradient(m -> sum(sin, tcm2(m)), mat)[1]
128-
129109
jrows(f,m) = Align(map(f, Slices(m, False(), True())), False(), True())
130110
@test res jrows(fun, mat)
131111
@test grad Zygote.gradient(m -> sum(sin, jrows(fun, m)), mat)[1]

0 commit comments

Comments
 (0)