|
11 | 11 | import tempfile |
12 | 12 | import typing |
13 | 13 | import unittest |
14 | | -from functools import partial |
15 | 14 | from pathlib import Path |
16 | 15 | from typing import * # noqa: F403 |
17 | 16 |
|
@@ -4157,148 +4156,6 @@ def test_any_output_is_alias_to_input_or_output(self): |
4157 | 4156 | ) |
4158 | 4157 | ) |
4159 | 4158 |
|
4160 | | - def test_library_get_kernel(self): |
4161 | | - """Test registering a custom kernel, using it, then deregistering and verifying error.""" |
4162 | | - |
4163 | | - # Register a dummy kernel for arange to the CPU key that returns a tensor of ones |
4164 | | - def dummy_arange_cpu( |
4165 | | - dispatch_keys, |
4166 | | - start, |
4167 | | - end, |
4168 | | - dtype=None, |
4169 | | - layout=torch.strided, |
4170 | | - device=None, |
4171 | | - pin_memory=False, |
4172 | | - ): |
4173 | | - size = max(0, int(end - start)) |
4174 | | - return torch.ones(size, dtype=dtype, device=device) |
4175 | | - |
4176 | | - with torch.library._scoped_library("aten", "IMPL") as lib: |
4177 | | - lib.impl("arange.start", dummy_arange_cpu, "CPU", with_keyset=True) |
4178 | | - |
4179 | | - kernel = torch.library.get_kernel("aten::arange.start", "CPU") |
4180 | | - dispatch_keys = torch._C.DispatchKeySet(torch._C.DispatchKey.CPU) |
4181 | | - result = kernel.call_boxed(dispatch_keys, 0, 5) |
4182 | | - |
4183 | | - self.assertEqual(result, torch.ones(5)) |
4184 | | - |
4185 | | - # The kernel should now be invalidated after exiting the scoped_library context |
4186 | | - with self.assertRaisesRegex(RuntimeError, "has been invalidated"): |
4187 | | - kernel.call_boxed(dispatch_keys, 0, 5) |
4188 | | - |
4189 | | - def test_library_get_kernel_with_conditional_dispatch(self): |
4190 | | - """Test registering a custom kernel with conditional dispatch logic.""" |
4191 | | - |
4192 | | - def conditional_arange_cpu1( |
4193 | | - original_kernel, |
4194 | | - dispatch_keys, |
4195 | | - start, |
4196 | | - end, |
4197 | | - dtype=None, |
4198 | | - layout=torch.strided, |
4199 | | - device=None, |
4200 | | - pin_memory=False, |
4201 | | - ): |
4202 | | - # If end is even, use the original kernel, otherwise return ones tensor |
4203 | | - if end % 2 == 0: |
4204 | | - op_handle = torch.ops.aten.arange.start._handle |
4205 | | - return original_kernel.call_boxed( |
4206 | | - dispatch_keys, |
4207 | | - start, |
4208 | | - end, |
4209 | | - dtype=dtype, |
4210 | | - layout=layout, |
4211 | | - device=device, |
4212 | | - pin_memory=pin_memory, |
4213 | | - ) |
4214 | | - else: |
4215 | | - size = max(0, int(end - start)) |
4216 | | - return torch.ones(size, dtype=dtype, device=device) |
4217 | | - |
4218 | | - def conditional_arange_cpu2( |
4219 | | - original_kernel, |
4220 | | - dispatch_keys, |
4221 | | - start, |
4222 | | - end, |
4223 | | - dtype=None, |
4224 | | - layout=torch.strided, |
4225 | | - device=None, |
4226 | | - pin_memory=False, |
4227 | | - ): |
4228 | | - # If start is even, use the original kernel, otherwise return twos tensor |
4229 | | - if start % 2 == 0: |
4230 | | - op_handle = torch.ops.aten.arange.start._handle |
4231 | | - return original_kernel.call_boxed( |
4232 | | - dispatch_keys, |
4233 | | - start, |
4234 | | - end, |
4235 | | - dtype=dtype, |
4236 | | - layout=layout, |
4237 | | - device=device, |
4238 | | - pin_memory=pin_memory, |
4239 | | - ) |
4240 | | - else: |
4241 | | - size = max(0, int(end - start)) |
4242 | | - return torch.empty(size, dtype=dtype, device=device).fill_(2) |
4243 | | - |
4244 | | - original_kernel = torch.library.get_kernel("aten::arange.start", "CPU") |
4245 | | - expected_result1, expected_result2 = torch.ones(5), torch.arange(0, 6) |
4246 | | - expected_result3, expected_result4, expected_result5 = ( |
4247 | | - torch.ones(5), |
4248 | | - torch.arange(0, 6), |
4249 | | - torch.ones(5).fill_(2), |
4250 | | - ) |
4251 | | - |
4252 | | - with torch.library._scoped_library("aten", "IMPL") as lib2: |
4253 | | - with torch.library._scoped_library("aten", "IMPL") as lib1: |
4254 | | - lib1.impl( |
4255 | | - "arange.start", |
4256 | | - partial(conditional_arange_cpu1, original_kernel), |
4257 | | - "CPU", |
4258 | | - with_keyset=True, |
4259 | | - ) |
4260 | | - |
4261 | | - self.assertEqual(torch.arange(0, 5), expected_result1) |
4262 | | - self.assertEqual(torch.arange(0, 6), expected_result2) |
4263 | | - new_original_kernel = torch.library.get_kernel( |
4264 | | - "aten::arange.start", "CPU" |
4265 | | - ) |
4266 | | - lib2.impl( |
4267 | | - "arange.start", |
4268 | | - partial(conditional_arange_cpu2, new_original_kernel), |
4269 | | - "CPU", |
4270 | | - allow_override=True, |
4271 | | - with_keyset=True, |
4272 | | - ) |
4273 | | - |
4274 | | - self.assertEqual(torch.arange(0, 5), expected_result3) |
4275 | | - self.assertEqual(torch.arange(0, 6), expected_result4) |
4276 | | - self.assertEqual(torch.arange(1, 6), expected_result5) |
4277 | | - |
4278 | | - # The kernel should now be invalidated after destroying lib1 |
4279 | | - with self.assertRaisesRegex(RuntimeError, "has been invalidated"): |
4280 | | - torch.arange(0, 5) |
4281 | | - |
4282 | | - # Should still work after destroying lib1 |
4283 | | - self.assertEqual(torch.arange(1, 6), expected_result5) |
4284 | | - |
4285 | | - def test_library_get_kernel_invalid(self): |
4286 | | - """Test that get_kernel raises an error when no kernel is available.""" |
4287 | | - with torch.library._scoped_library("test_invalid_kernel", "DEF") as lib: |
4288 | | - lib.define("cpu_only_op(Tensor x) -> Tensor") |
4289 | | - lib.impl("cpu_only_op", lambda x: x * 2, "CPU") |
4290 | | - |
4291 | | - cpu_kernel = torch.library.get_kernel( |
4292 | | - "test_invalid_kernel::cpu_only_op", "CPU" |
4293 | | - ) |
4294 | | - self.assertIsNotNone(cpu_kernel) |
4295 | | - |
4296 | | - # CUDA should fail at the isValid() check since no CUDA kernel exists |
4297 | | - with self.assertRaisesRegex( |
4298 | | - RuntimeError, "no kernel for CUDA for test_invalid_kernel::cpu_only_op" |
4299 | | - ): |
4300 | | - torch.library.get_kernel("test_invalid_kernel::cpu_only_op", "CUDA") |
4301 | | - |
4302 | 4159 |
|
4303 | 4160 | class MiniOpTestOther(CustomOpTestCaseBase): |
4304 | 4161 | test_ns = "mini_op_test" |
|
0 commit comments