Skip to content

Commit c0fa128

Browse files
committed
Update convolution benchmark to new syntax.
1 parent bbec9ad commit c0fa128

File tree

2 files changed

+27
-12
lines changed

2 files changed

+27
-12
lines changed

benchmarks/conv.dx

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,13 @@ This computation is interesting because it occurs in the inner
88
loop of computing the Neural Tangent Kernel of a convolutional
99
layer.
1010

11-
def unsafe_from_integer {n} [Ix n] (i:Int) : n =
12-
unsafe_from_ordinal _ $ unsafe_i_to_n i
11+
def unsafe_from_integer(i:Int) -> n given (n|Ix) =
12+
unsafe_from_ordinal $ unsafe_i_to_n i
1313

14-
def conv_1d {d1 d2} (kernel: (Fin d1)=>(Fin d2)=>Float)
15-
(size: Nat) : (Fin d1)=>(Fin d2)=>Float =
14+
def conv_1d(
15+
kernel: (Fin d1)=>(Fin d2)=>Float,
16+
size: Nat)
17+
-> (Fin d1)=>(Fin d2)=>Float given (d1, d2) =
1618
half_kernel_size = (f_to_i $ (n_to_f size) / 2.0)
1719
for i j. sum for k: (Fin size).
1820
i' = n_to_i $ ordinal i
@@ -22,14 +24,18 @@ def conv_1d {d1 d2} (kernel: (Fin d1)=>(Fin d2)=>Float)
2224
j'' = j' + k' - half_kernel_size
2325
if i'' < 0 || i'' >= (n_to_i d1) || j'' < 0 || j'' >= (n_to_i d2)
2426
then 0
25-
else kernel.(unsafe_from_integer i'').(unsafe_from_integer j'')
26-
27-
def conv {n c h w} (kernel: (Fin n)=>(Fin c)=>(Fin h)=>(Fin w)=>Float)
28-
(size: Int) : (Fin n)=>(Fin c)=>(Fin h)=>(Fin w)=>Float =
29-
for n' c'. conv_1d kernel.n'.c' (unsafe_i_to_n size)
30-
31-
def conv_spec {n c h w} (kernel: (Fin n)=>(Fin c)=>(Fin h)=>(Fin w)=>Float)
32-
(size: Int) : (Fin n)=>(Fin c)=>(Fin h)=>(Fin w)=>Float =
27+
else kernel[unsafe_from_integer i'', unsafe_from_integer j'']
28+
29+
def conv(
30+
kernel: (Fin n)=>(Fin c)=>(Fin h)=>(Fin w)=>Float,
31+
size: Int)
32+
-> (Fin n)=>(Fin c)=>(Fin h)=>(Fin w)=>Float given (n, c, h, w) =
33+
for n' c'. conv_1d(kernel[n', c'], unsafe_i_to_n(size))
34+
35+
def conv_spec(
36+
kernel: (Fin n)=>(Fin c)=>(Fin h)=>(Fin w)=>Float,
37+
size: Int)
38+
-> (Fin n)=>(Fin c)=>(Fin h)=>(Fin w)=>Float given (n, c, h, w) =
3339
if size == 3
3440
then conv kernel 3
3541
else conv kernel size

lib/prelude.dx

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -449,6 +449,15 @@ instance Ix((a, b, c)) given (a|Ix, b|Ix, c|Ix)
449449
(i, (j, k)) = unsafe_from_ordinal o
450450
(i, j, k)
451451

452+
instance Ix((a, b, c, d)) given (a|Ix, b|Ix, c|Ix, d|Ix)
453+
def size'() = size a * size b * size c * size d
454+
def ordinal(tup) =
455+
(i, j, k, m) = tup
456+
ordinal((i,(j,(k,m))))
457+
def unsafe_from_ordinal(o) =
458+
(i, (j, (k, m))) = unsafe_from_ordinal o
459+
(i, j, k, m)
460+
452461
'## Vector spaces
453462

454463
interface VSpace(a|Add|Sub)

0 commit comments

Comments
 (0)