File tree Expand file tree Collapse file tree 2 files changed +32
-1
lines changed Expand file tree Collapse file tree 2 files changed +32
-1
lines changed Original file line number Diff line number Diff line change @@ -2576,6 +2576,7 @@ def tile(
25762576 body (FullTileIx(n, tile_size, tile_ix'))
25772577 body (CodaIx(n, coda_offset, coda_size))
25782578
2579+ @noinline
25792580def tiled_matmul(
25802581 x: l=>m=>Float,
25812582 y: m=>n=>Float
@@ -2601,9 +2602,18 @@ def (**)(
26012602 x: l=>m=>Float,
26022603 y: m=>n=>Float
26032604 ) -> l=>n=>Float given (l|Ix, m|Ix, n|Ix) =
2604- -- TODO(https://github.com/google-research/dex-lang/issues/1212) Replace with tiled_matmul.
26052605 tiled_matmul(x, y)
26062606
2607+ def matmul_linearization(
2608+ x: l=>m=>Float,
2609+ y: m=>n=>Float
2610+ ) -> _ given (l|Ix, m|Ix, n|Ix) =
2611+ def lin(xt: l=>m=>Float, yt: m=>n=>Float) -> _ =
2612+ x ** yt + xt ** y
2613+ (x ** y, lin)
2614+
2615+ custom-linearization tiled_matmul matmul_linearization
2616+
26072617def (**.)(mat: n=>m=>Float, v: m=>Float) -> (n=>Float) given (n|Ix, m|Ix) =
26082618 for i. vdot(mat[i], v)
26092619def(.**)(v: n=>Float, mat: n=>m=>Float) -> (m=>Float) given (n|Ix, m|Ix) =
Original file line number Diff line number Diff line change @@ -498,3 +498,24 @@ deriv (\x. q x 1.0) 1.0
498498-- Correct derivative, based on the SomeTangent branch
499499deriv (\x. q x x) 1.0
500500> 3.
501+
502+ ------ Check custom linearization of matmul ------
503+
504+ amat = for i:(Fin 100) j:(Fin 100). n_to_f $ ordinal (i, j)
505+
506+ -- The derivative of matmul should give the same answers as a direct
507+ -- matmul (this checks that the custom derivative is not too busted).
508+
509+ def mmp'(m1:l=>m=>Float, m2:m=>n=>Float) -> l=>n=>Float given (l|Ix, m|Ix, n|Ix) =
510+ jvp (\m. m1 ** m) m2 m2
511+
512+ :p mmp'(amat, amat) ~~ naive_matmul(amat, amat)
513+ > True
514+
515+ -- Let's check the other orientation too
516+
517+ def mmp''(m1:l=>m=>Float, m2:m=>n=>Float) -> l=>n=>Float given (l|Ix, m|Ix, n|Ix) =
518+ jvp (\m. m ** m2) m1 m1
519+
520+ :p mmp''(amat, amat) ~~ naive_matmul(amat, amat)
521+ > True
You can’t perform that action at this time.
0 commit comments