Skip to content

Commit 289a43e

Browse files
axch and emilyfertig committed
Actually code the custom linearization for matmul
so that we don't embarrass ourselves too much with how poorly AD does thereon. Co-authored-by: Emily Fertig <[email protected]>
1 parent 89343a2 commit 289a43e

File tree

2 files changed

+32
-1
lines changed

2 files changed

+32
-1
lines changed

lib/prelude.dx

Lines changed: 11 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -2576,6 +2576,7 @@ def tile(
25762576
body (FullTileIx(n, tile_size, tile_ix'))
25772577
body (CodaIx(n, coda_offset, coda_size))
25782578

2579+
@noinline
25792580
def tiled_matmul(
25802581
x: l=>m=>Float,
25812582
y: m=>n=>Float
@@ -2601,9 +2602,18 @@ def (**)(
26012602
x: l=>m=>Float,
26022603
y: m=>n=>Float
26032604
) -> l=>n=>Float given (l|Ix, m|Ix, n|Ix) =
2604-
-- TODO(https://github.com/google-research/dex-lang/issues/1212) Replace with tiled_matmul.
26052605
tiled_matmul(x, y)
26062606

2607+
def matmul_linearization(
2608+
x: l=>m=>Float,
2609+
y: m=>n=>Float
2610+
) -> _ given (l|Ix, m|Ix, n|Ix) =
2611+
def lin(xt: l=>m=>Float, yt: m=>n=>Float) -> _ =
2612+
x ** yt + xt ** y
2613+
(x ** y, lin)
2614+
2615+
custom-linearization tiled_matmul matmul_linearization
2616+
26072617
def (**.)(mat: n=>m=>Float, v: m=>Float) -> (n=>Float) given (n|Ix, m|Ix) =
26082618
for i. vdot(mat[i], v)
26092619
def(.**)(v: n=>Float, mat: n=>m=>Float) -> (m=>Float) given (n|Ix, m|Ix) =

tests/ad-tests.dx

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -498,3 +498,24 @@ deriv (\x. q x 1.0) 1.0
498498
-- Correct derivative, based on the SomeTangent branch
499499
deriv (\x. q x x) 1.0
500500
> 3.
501+
502+
------ Check custom linearization of matmul ------
503+
504+
amat = for i:(Fin 100) j:(Fin 100). n_to_f $ ordinal (i, j)
505+
506+
-- The derivative of matmul should give the same answers as a direct
507+
-- matmul (this checks that the custom derivative is not too busted).
508+
509+
def mmp'(m1:l=>m=>Float, m2:m=>n=>Float) -> l=>n=>Float given (l|Ix, m|Ix, n|Ix) =
510+
jvp (\m. m1 ** m) m2 m2
511+
512+
:p mmp'(amat, amat) ~~ naive_matmul(amat, amat)
513+
> True
514+
515+
-- Let's check the other orientation too
516+
517+
def mmp''(m1:l=>m=>Float, m2:m=>n=>Float) -> l=>n=>Float given (l|Ix, m|Ix, n|Ix) =
518+
jvp (\m. m ** m2) m1 m1
519+
520+
:p mmp''(amat, amat) ~~ naive_matmul(amat, amat)
521+
> True

0 commit comments

Comments
 (0)