Skip to content

Commit 40fec7c

Browse files
committed
add repmat & remove nnlib
1 parent d5e78ce commit 40fec7c

File tree

5 files changed

+48
-12
lines changed

5 files changed

+48
-12
lines changed

REQUIRE

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,2 @@
11
julia 0.6
22
StaticArrays
3-
NNlib

src/GPUArrays.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@ include("convolution.jl")
1818
include("testsuite/testsuite.jl")
1919
include("jlbackend.jl")
2020
include("random.jl")
21-
include("nnlib.jl")
2221

2322
export GPUArray, gpu_call, thread_blocks_heuristic, global_size, synchronize_threads
2423
export linear_index, @linearidx, @cartesianidx, convolution!, device, synchronize

src/base.jl

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,3 +92,44 @@ function _sub2ind(inds, L, ind, i::IT, I::IT...) where IT
9292
r1 = inds[1]
9393
_sub2ind(Base.tail(inds), L * r1, ind + (i - IT(1)) * L, I...)
9494
end
95+
96+
# Uniform `(i, j)` element access for vectors and matrices.
# A vector ignores the column index `j` and is read with `i` only; a matrix is read at
# `[i, j]`. `Base.@propagate_inbounds` already asks for inlining and forwards the
# caller's `@inbounds` context to the indexing expression.
Base.@propagate_inbounds function getidx_2d1d(x::AbstractVector, i, j)
    return x[i]
end
Base.@propagate_inbounds function getidx_2d1d(x::AbstractMatrix, i, j)
    return x[i, j]
end
98+
99+
"""
    repmat(a::GPUVecOrMat, m::Int, n::Int = 1)

Tile `a` `m` times along the first dimension and `n` times along the second.

The result is allocated with `similar(a, size(a, 1) * m, size(a, 2) * n)` and returned
after the copy kernel has been submitted through `gpu_call`. A vector `a` is treated as
a single column, since `size(a, 2) == 1`.
"""
function Base.repmat(a::GPUVecOrMat, m::Int, n::Int = 1)
    o, p = size(a, 1), size(a, 2)
    b = similar(a, o * m, p * n)
    # Nothing to copy when the result has no elements. Returning early also avoids
    # asking `gpu_call` for a zero-sized launch (e.g. `n == 0`).
    (o * m == 0 || p * n == 0) && return b
    # Sizes are handed to the kernel as UInt32 — presumably to keep the index
    # arithmetic in 32 bits on the device.
    args = (b, a, UInt32.((o, p, m, n))...)
    # `n` is passed as the launch size: one kernel invocation per horizontal tile.
    gpu_call(a, args, n) do state, b, a, o, p, m, n
        j = linear_index(state)
        # Invocations beyond the requested `n` have no tile to fill.
        j > n && return
        ui1 = UInt32(1)
        # First destination column of the j-th horizontal tile.
        d = (j - ui1) * p + ui1
        @inbounds for i in ui1:m
            # First destination row of the i-th vertical copy.
            c = (i - ui1) * o + ui1
            for r in ui1:p
                for k in ui1:o
                    b[k - ui1 + c, r - ui1 + d] = getidx_2d1d(a, k, r)
                end
            end
        end
        return
    end
    return b
end
120+
121+
"""
    repmat(a::GPUVector, m::Int)

Concatenate `m` copies of `a` end to end.

The result is allocated with `similar(a, length(a) * m)` and returned after the copy
kernel has been submitted through `gpu_call`.
"""
function Base.repmat(a::GPUVector, m::Int)
    o = length(a)
    b = similar(a, o * m)
    # Nothing to copy when the result has no elements. Returning early also avoids
    # asking `gpu_call` for a zero-sized launch (`m == 0`).
    o * m == 0 && return b
    # `m` is passed as the launch size: one kernel invocation per copy of `a`.
    gpu_call(a, (b, a, UInt32(o), UInt32(m)), m) do state, b, a, o, m
        i = linear_index(state)
        # Invocations beyond the requested `m` have no copy to write.
        i > m && return
        ui1 = UInt32(1)
        # First destination index of the i-th copy.
        c = (i - ui1) * o + ui1
        # The loop variable is deliberately distinct from `i`, so the thread index
        # is neither shadowed nor overwritten.
        @inbounds for k in ui1:o
            b[c + k - ui1] = a[k]
        end
        return
    end
    return b
end

src/nnlib.jl

Lines changed: 0 additions & 10 deletions
This file was deleted.

src/testsuite/base.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,5 +133,12 @@ function run_base(Typ)
133133
against_base((a, b)-> map!(-, a, b), T, (10,), (10,))
134134
against_base((a, b, c, d)-> map!(*, a, b, c, d), T, (10,), (10,), (10,), (10,))
135135
end
136+
137+
@testset "repmat" begin
138+
against_base(a-> repmat(a, 5, 6), T, (10,))
139+
against_base(a-> repmat(a, 5), T, (10,))
140+
against_base(a-> repmat(a, 5), T, (5, 4))
141+
against_base(a-> repmat(a, 4, 3), T, (10, 15))
142+
end
136143
end
137144
end

0 commit comments

Comments
 (0)