@@ -197,5 +197,61 @@ def test_gguf_convert_to_gguf_gpu_layersCpu(self):  # Quantizer on GPU, layers on CPU
     def test_gguf_convert_to_gguf_gpu_layersDevice(self):  # Quantizer on GPU, layers on GPU
         self._run_gguf_conversion_test(quantizer_device=self.device, gguf_cpu_offload=False)
 
+# New test class for move_to_device
+from quantllm.quant.quantization_engine import move_to_device
+import torch.nn as nn
+
+class TestMoveToDevice(unittest.TestCase):
+    def test_move_tensor_and_module(self):
+        """Test move_to_device with both torch.Tensor and torch.nn.Module."""
+        target_device_str = "cuda" if torch.cuda.is_available() else "cpu"
+        target_device = torch.device("cuda:0" if target_device_str == "cuda" else "cpu")  # pin the index: torch.device("cuda") != torch.device("cuda:0")
+
+        # 1. Create a simple torch.Tensor on the CPU
+        my_tensor = torch.randn(2, 3, device="cpu")
+
+        # 2. Create a simple torch.nn.Module on the CPU
+        my_module = nn.Linear(10, 10).to("cpu")
+
+        # 3. Call move_to_device for the tensor and the module
+        moved_tensor = move_to_device(my_tensor, target_device)
+        moved_module = move_to_device(my_module, target_device)
+
+        # 4. Assert that the tensor is on the target device
+        self.assertEqual(moved_tensor.device, target_device, "Tensor not moved to target device.")
+
+        # 5. Assert that move_to_device returned a Module, then check its parameters' device
+        self.assertIsInstance(moved_module, nn.Module, "move_to_device did not return a Module.")
+
+        # Check the device of a parameter (nn.Linear always has parameters)
+        if list(moved_module.parameters()):
+            self.assertEqual(
+                next(moved_module.parameters()).device,
+                target_device,
+                "Module's parameters not moved to target device."
+            )
+        else:
+            # Parameter-free modules (e.g. nn.ReLU()) carry no device to check, so
+            # there is nothing to assert; this branch is never hit for nn.Linear but
+            # keeps the test reusable for such modules later.
+            pass
+
+        # Test force_copy=True for a tensor already on the target device
+        another_tensor = torch.randn(2, 3, device=target_device)
+        copied_tensor = move_to_device(another_tensor, target_device, force_copy=True)
+        self.assertEqual(copied_tensor.device, target_device)
+        if target_device_str == "cpu":
+            pass  # On CPU, .to() may return the same object; the device assertion above suffices.
+        else:
+            # On CUDA, force_copy=True should produce a distinct tensor object.
+            if another_tensor.is_cuda and copied_tensor.is_cuda:
+                self.assertNotEqual(another_tensor.data_ptr(), copied_tensor.data_ptr(), "force_copy=True did not create a new tensor copy on CUDA.")
+
+        # Test moving a module that is already on the target device
+        module_on_target = nn.Linear(5, 5).to(target_device)
+        moved_module_again = move_to_device(module_on_target, target_device)
+        self.assertEqual(next(moved_module_again.parameters()).device, target_device)
+
+
 if __name__ == '__main__':
     unittest.main()
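
The move_to_device helper exercised above lives in quantllm.quant.quantization_engine and is not part of this diff. As a point of reference only, here is a minimal sketch of the behaviour the tests assume; the signature move_to_device(obj, device, force_copy=False) is inferred from the calls in the test, and the real implementation may differ:

import torch
import torch.nn as nn

def move_to_device(obj, device, force_copy=False):
    """Assumed sketch: move a tensor or module to `device`; force_copy applies to tensors."""
    device = torch.device(device)
    if isinstance(obj, torch.Tensor):
        # Tensor.to(..., copy=True) returns a new tensor even when obj is already on `device`.
        return obj.to(device, copy=True) if force_copy else obj.to(device)
    if isinstance(obj, nn.Module):
        # Module.to() moves parameters and buffers in place and returns the module itself.
        return obj.to(device)
    raise TypeError(f"move_to_device expects a Tensor or nn.Module, got {type(obj).__name__}")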