@@ -2158,66 +2158,6 @@ def test_mat_vec_ops(self, func, x_shape, y_shape, make_expected):
21582158 expected = make_expected (x_val , y_val )
21592159 np .testing .assert_allclose (f (x_val , y_val ), expected )
21602160
2161- def test_matmul (self ):
2162- """Test matmul function with various input shapes."""
2163- rng = np .random .default_rng (seed = utt .fetch_seed ())
2164-
2165- # Test matrix-matrix
2166- x = matrix ()
2167- y = matrix ()
2168- z = matmul (x , y )
2169- f = function ([x , y ], z )
2170-
2171- x_val = random (3 , 4 , rng = rng ).astype (config .floatX )
2172- y_val = random (4 , 5 , rng = rng ).astype (config .floatX )
2173- np .testing .assert_allclose (f (x_val , y_val ), np .matmul (x_val , y_val ))
2174-
2175- # Test vector-matrix
2176- x = vector ()
2177- y = matrix ()
2178- z = matmul (x , y )
2179- f = function ([x , y ], z )
2180-
2181- x_val = random (3 , rng = rng ).astype (config .floatX )
2182- y_val = random (3 , 4 , rng = rng ).astype (config .floatX )
2183- np .testing .assert_allclose (f (x_val , y_val ), np .matmul (x_val , y_val ))
2184-
2185- # Test matrix-vector
2186- x = matrix ()
2187- y = vector ()
2188- z = matmul (x , y )
2189- f = function ([x , y ], z )
2190-
2191- x_val = random (3 , 4 , rng = rng ).astype (config .floatX )
2192- y_val = random (4 , rng = rng ).astype (config .floatX )
2193- np .testing .assert_allclose (f (x_val , y_val ), np .matmul (x_val , y_val ))
2194-
2195- # Test vector-vector
2196- x = vector ()
2197- y = vector ()
2198- z = matmul (x , y )
2199- f = function ([x , y ], z )
2200-
2201- x_val = random (3 , rng = rng ).astype (config .floatX )
2202- y_val = random (3 , rng = rng ).astype (config .floatX )
2203- np .testing .assert_allclose (f (x_val , y_val ), np .matmul (x_val , y_val ))
2204-
2205- # Test batched
2206- x = tensor3 ()
2207- y = tensor3 ()
2208- z = matmul (x , y )
2209- f = function ([x , y ], z )
2210-
2211- x_val = random (2 , 3 , 4 , rng = rng ).astype (config .floatX )
2212- y_val = random (2 , 4 , 5 , rng = rng ).astype (config .floatX )
2213- np .testing .assert_allclose (f (x_val , y_val ), np .matmul (x_val , y_val ))
2214-
2215- # Test error cases
2216- x = scalar ()
2217- y = scalar ()
2218- with pytest .raises (ValueError ):
2219- matmul (x , y )
2220-
22212161
22222162class TestTensordot :
22232163 def TensorDot (self , axes ):
0 commit comments