@@ -56,6 +56,23 @@ def test_computes_product_of_radial_basis_function(self):
5656 res = kernel (a , b ).evaluate ()
5757 self .assertLess (torch .norm (res - actual ), 2e-5 )
5858
59+ def test_computes_product_of_radial_basis_function_batch (self ):
60+ a = torch .tensor ([4 , 2 , 8 ], dtype = torch .float ).view (3 , 1 )
61+ b = torch .tensor ([0 , 2 ], dtype = torch .float ).view (2 , 1 )
62+ lengthscale = 2
63+
64+ kernel_1 = RBFKernel (batch_shape = torch .Size ([4 ])).initialize (lengthscale = lengthscale )
65+ kernel_2 = RBFKernel ().initialize (lengthscale = lengthscale )
66+ kernel = kernel_1 * kernel_2
67+
68+ actual = torch .tensor ([[16 , 4 ], [4 , 0 ], [64 , 36 ]], dtype = torch .float )
69+ actual = actual .mul_ (- 0.5 ).div_ (lengthscale ** 2 ).exp () ** 2
70+ actual = actual .repeat (4 , 1 , 1 )
71+
72+ kernel .eval ()
73+ res = kernel (a , b ).evaluate ()
74+ self .assertLess (torch .norm (res - actual ), 2e-5 )
75+
5976 def test_computes_sum_of_radial_basis_function (self ):
6077 a = torch .tensor ([4 , 2 , 8 ], dtype = torch .float ).view (3 , 1 )
6178 b = torch .tensor ([0 , 2 ], dtype = torch .float ).view (2 , 1 )
0 commit comments