
Commit 65efe8e

Use float union type alias and remove CRC for now

1 parent b3e9682 · commit 65efe8e

File tree

ext/NNlibCUDA/Project.toml
ext/NNlibCUDA/src/NNlibCUDA.jl
ext/NNlibCUDA/src/cudnn/batchnorm.jl

3 files changed: 4 additions & 8 deletions
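The "float union type alias" in the commit title refers to the `T<:CUDNNFloat` bound that replaces the repeated `Union{Float32, Float64}` annotations in batchnorm.jl below. The alias definition itself is not part of this diff; as a minimal sketch of the pattern, assuming the alias lives elsewhere in NNlibCUDA and names the floating-point element types the cuDNN wrappers accept (its exact members may differ, e.g. it may include Float16):

# Sketch only: the committed alias is defined elsewhere in NNlibCUDA.
# A union alias like this names the supported element types once, instead of
# repeating the union in every method signature.
const CUDNNFloat = Union{Float16, Float32, Float64}

# The pattern the diffs below adopt: one shared bound for dispatch.
gpu_supported(::Type{T}) where T<:CUDNNFloat = true
gpu_supported(::Type) = false

gpu_supported(Float32)  # true
gpu_supported(Int64)    # false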

ext/NNlibCUDA/Project.toml

Lines changed: 0 additions & 2 deletions
@@ -4,14 +4,12 @@ version = "0.1.11"
 
 [deps]
 CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
-ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 
 [compat]
-ChainRulesCore = "1"
 CUDA = "3.3.1"
 NNlib = "0.7.31"
 julia = "1.6"

ext/NNlibCUDA/src/NNlibCUDA.jl

Lines changed: 0 additions & 2 deletions
@@ -3,8 +3,6 @@ module NNlibCUDA
 using NNlib
 using CUDA
 using Random, Statistics
-using ChainRulesCore: NoTangent, ZeroTangent
-import ChainRulesCore: rrule
 
 const IntOrIntTuple = Union{Integer, NTuple{N,<:Integer} where N}
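The two deleted lines were this module's only references to ChainRulesCore, matching the "remove CRC for now" half of the commit title. For context, a hypothetical example (not code from this repository) of the kind of custom reverse rule those imports would have supported:

# Hypothetical illustration, not NNlibCUDA code: `import ChainRulesCore: rrule`
# lets a package add methods defining the reverse (pullback) rules that
# AD systems query for a function.
using ChainRulesCore: NoTangent
import ChainRulesCore: rrule

square(x::Real) = x^2

function rrule(::typeof(square), x::Real)
    y = square(x)
    # The pullback maps the output cotangent ȳ to cotangents of each argument;
    # NoTangent() stands in for the non-differentiable function object itself.
    square_pullback(ȳ) = (NoTangent(), 2x * ȳ)
    return y, square_pullback
end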

ext/NNlibCUDA/src/cudnn/batchnorm.jl

Lines changed: 4 additions & 4 deletions
@@ -44,7 +44,7 @@ function cudnnBNForward!(y::DenseCuArray{T}, g::DenseCuArray{T}, b::DenseCuArray
                          eps = T(1e-5),
                          training = true,
                          affine = true,
-                         track_stats = true) where T<:Union{Float32, Float64}
+                         track_stats = true) where T<:CUDNNFloat
   dims = _wsize(x)
   if eps < CUDNN_BN_MIN_EPSILON
     @warn "eps $eps is too small for CuDNN, setting to CUDNN_BN_MIN_EPSILON=$CUDNN_BN_MIN_EPSILON"
@@ -99,7 +99,7 @@ end
 
 function ∇batchnorm(g::DenseCuArray{T}, b::DenseCuArray{T}, x::DenseCuArray{T, 2}, dy::DenseCuArray{T, 2},
                     running_mean, running_var, momentum;
-                    kws...) where T<:Union{Float32, Float64}
+                    kws...) where T<:CUDNNFloat
   dg, db, dx = ∇batchnorm(g, b, reshape(x, 1, 1, size(x, 1), size(x, 2)), reshape(dy, 1, 1, size(dy, 1),
                           size(dy, 2)), running_mean, running_var, momentum; kws...)
   (dg, db, dropdims(dx, dims = (1, 2)))
@@ -108,7 +108,7 @@ end
 
 function ∇batchnorm(g::DenseCuArray{T}, b::DenseCuArray{T}, x::DenseCuArray{T}, dy::DenseCuArray{T},
                     running_mean, running_var, momentum;
-                    affine=true, kws...) where T<:Union{Float32, Float64}
+                    affine=true, kws...) where T<:CUDNNFloat
   dg = similar(g)
   db = similar(b)
   dx = similar(x)
@@ -127,7 +127,7 @@ function cudnnBNBackward!(dg::DenseCuArray{T}, g::DenseCuArray{T}, db::DenseCuAr
                           momentum; cache = nothing, eps = T(1e-5),
                           alpha = T(1), beta = T(0),
                           dalpha = T(1), dbeta = T(0), training = true,
-                          track_stats = true) where T<:Union{Float32, Float64}
+                          track_stats = true) where T<:CUDNNFloat
   if !track_stats
     running_mean = CU_NULL
     running_var = CU_NULL
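The 2D method of ∇batchnorm above works by lifting matrices into the 4D layout cuDNN expects, calling the 4D method, and then dropping the singleton dimensions from the result. A standalone sketch of that reshape round-trip on plain CPU arrays (helper name hypothetical):

# Standalone illustration of the 2D -> 4D -> 2D round-trip used by the 2D
# ∇batchnorm method above; `lift4d` is a hypothetical helper, not NNlibCUDA API.
lift4d(x::AbstractMatrix) = reshape(x, 1, 1, size(x, 1), size(x, 2))

x  = rand(Float32, 3, 5)           # (features, batch)
x4 = lift4d(x)                     # size (1, 1, 3, 5): the layout cuDNN expects
x2 = dropdims(x4, dims = (1, 2))   # squeeze the singleton dims back out
@assert x2 == x && size(x4) == (1, 1, 3, 5)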
