@@ -505,52 +505,54 @@ def model_config_fp16():
505505 return deepcopy (ToyModel4 ().half ())
506506
507507
# QLinear requires an Nvidia GPU with CUDA, so everything that touches it
# (the model, its fixture params, and the fixture itself) is only defined
# when CUDA is present.
# NOTE(review): on non-CUDA machines the `model_quantized` fixture will not
# exist at all, so dependent tests error with "fixture not found" rather
# than skipping — confirm this is the intended behavior.
if torch.cuda.is_available():

    class ToyModelQuantized(torch.nn.Module):
        """
        Three layer Linear model that has a quantized layer

        Extends:
            torch.nn.Module
        """

        def __init__(self):
            super().__init__()
            # QLinear requires a qconfig to work.
            qlinear_kwargs = {"qcfg": qconfig_init()}
            self.first_layer = torch.nn.Linear(3, 3, bias=True)
            self.second_layer = QLinear(3, 3, bias=True, **qlinear_kwargs)
            self.third_layer = torch.nn.Linear(3, 3, bias=True)

        def forward(self, input_tensor):
            """
            Forward func for Toy Model

            Args:
                input_tensor (torch.FloatTensor): Tensor to operate on

            Returns:
                torch.FloatTensor:
            """
            # Feed the input through the three layers in order.
            hidden = self.second_layer(self.first_layer(input_tensor))
            return self.third_layer(hidden)

    # Parametrization list for the fixture below.
    model_quantized_params = [ToyModelQuantized()]

    @pytest.fixture(scope="function", params=model_quantized_params)
    def model_quantized(request):
        """
        Toy Model that has quantized layer

        Args:
            request (torch.nn.Module): Toy Model

        Returns:
            torch.nn.Module: Toy Model
        """
        # Deep-copy so each test gets a pristine, unshared model instance.
        return deepcopy(request.param)
554556
555557
# Get a model to test layer uniqueness