1 parent dc67b9d commit 3f07fe5
test/ext/CUDA/cuda.jl
@@ -8,6 +8,7 @@ using Bijectors, CUDA, Distributions, Flux, LinearAlgebra, Test
 @testset "rand with CUDA" begin
     # Bijectors versions use dot for broadcasting, which causes issues with CUDA.
+    # https://github.com/TuringLang/Bijectors.jl/blob/6f0d383f73afd150a018b65a3ea4ac9306065d38/src/bijectors/planar_layer.jl#L65-L80
     function Bijectors.get_u_hat(u::CuVector{T}, w::CuVector{T}) where {T<:Real}
         wT_u = dot(w, u)
         scale = (Bijectors.LogExpFunctions.log1pexp(-wT_u) - 1) / sum(abs2, w)
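Note (not part of the commit): the CuVector override works on the GPU because `dot` and `sum(abs2, ...)` dispatch to device-side reductions instead of touching elements one at a time from the host. A minimal standalone sketch of that distinction, assuming a working CUDA device; the array sizes and values below are illustrative only:

using CUDA, LinearAlgebra

CUDA.allowscalar(false)        # fail loudly on any fallback to scalar indexing

u = CUDA.rand(Float32, 4)
w = CUDA.rand(Float32, 4)

wT_u = dot(w, u)               # CUBLAS reduction; returns a host scalar
s = sum(abs2, w)               # device-side mapreduce, also GPU-friendly

# By contrast, element-wise host access such as `u[1]` would throw a
# scalar-indexing error under allowscalar(false).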