@@ -8,11 +8,13 @@ This computation is interesting because it occurs in the inner
 loop of computing the Neural Tangent Kernel of a convolutional
 layer.
 
-def unsafe_from_integer {n} [Ix n] (i:Int) : n =
-  unsafe_from_ordinal _ $ unsafe_i_to_n i
+def unsafe_from_integer(i:Int) -> n given (n|Ix) =
+  unsafe_from_ordinal $ unsafe_i_to_n i
 
-def conv_1d {d1 d2} (kernel: (Fin d1)=>(Fin d2)=>Float)
-                    (size: Nat) : (Fin d1)=>(Fin d2)=>Float =
+def conv_1d(
+    kernel: (Fin d1)=>(Fin d2)=>Float,
+    size: Nat)
+    -> (Fin d1)=>(Fin d2)=>Float given (d1, d2) =
   half_kernel_size = (f_to_i $ (n_to_f size) / 2.0)
   for i j. sum for k: (Fin size).
     i' = n_to_i $ ordinal i
@@ -22,14 +24,18 @@ def conv_1d {d1 d2} (kernel: (Fin d1)=>(Fin d2)=>Float)
     j'' = j' + k' - half_kernel_size
     if i'' < 0 || i'' >= (n_to_i d1) || j'' < 0 || j'' >= (n_to_i d2)
       then 0
-      else kernel.(unsafe_from_integer i'').(unsafe_from_integer j'')
-
-def conv {n c h w} (kernel: (Fin n)=>(Fin c)=>(Fin h)=>(Fin w)=>Float)
-                   (size: Int) : (Fin n)=>(Fin c)=>(Fin h)=>(Fin w)=>Float =
-  for n' c'. conv_1d kernel.n'.c' (unsafe_i_to_n size)
-
-def conv_spec {n c h w} (kernel: (Fin n)=>(Fin c)=>(Fin h)=>(Fin w)=>Float)
-                        (size: Int) : (Fin n)=>(Fin c)=>(Fin h)=>(Fin w)=>Float =
+      else kernel[unsafe_from_integer i'', unsafe_from_integer j'']
+
+def conv(
+    kernel: (Fin n)=>(Fin c)=>(Fin h)=>(Fin w)=>Float,
+    size: Int)
+    -> (Fin n)=>(Fin c)=>(Fin h)=>(Fin w)=>Float given (n, c, h, w) =
+  for n' c'. conv_1d(kernel[n', c'], unsafe_i_to_n(size))
+
+def conv_spec(
+    kernel: (Fin n)=>(Fin c)=>(Fin h)=>(Fin w)=>Float,
+    size: Int)
+    -> (Fin n)=>(Fin c)=>(Fin h)=>(Fin w)=>Float given (n, c, h, w) =
   if size == 3
     then conv kernel 3
     else conv kernel size
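
For reference, here is a minimal usage sketch of the updated definitions in the new call syntax. The name ntk_kernel, the table sizes, and the fill values are illustrative assumptions, not part of the commit:

    -- Hypothetical example: build a small NCHW-shaped table and apply the
    -- size-specialized convolution to it (sizes chosen arbitrarily).
    ntk_kernel = for n: (Fin 2). for c: (Fin 3). for h: (Fin 8). for w: (Fin 8).
      n_to_f(ordinal h + ordinal w)
    smoothed = conv_spec(ntk_kernel, 3)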