@@ -114,11 +114,11 @@ def __init__( # pylint: disable=too-many-arguments
114
114
self .norm2 : torch .nn .Module = torch .nn .LayerNorm (embedding_dim )
115
115
116
116
self .bias : torch .Tensor
117
- self .register_buffer (' bias' , torch .zeros ([self .routed_num ]))
117
+ self .register_buffer (" bias" , torch .zeros ([self .routed_num ]))
118
118
self .accumulater : torch .Tensor
119
- self .register_buffer (' accumulater' , torch .zeros ([self .routed_num ]))
119
+ self .register_buffer (" accumulater" , torch .zeros ([self .routed_num ]))
120
120
self .count : torch .Tensor
121
- self .register_buffer (' count' , torch .zeros ([]))
121
+ self .register_buffer (" count" , torch .zeros ([]))
122
122
123
123
def forward (
124
124
self ,
@@ -320,9 +320,9 @@ def __init__( # pylint: disable=too-many-arguments
320
320
if isinstance (ordering , int ) and ordering == - 1 :
321
321
ordering = list (reversed (range (self .sites )))
322
322
self .ordering : torch .Tensor
323
- self .register_buffer (' ordering' , torch .tensor (ordering , dtype = torch .int64 ))
323
+ self .register_buffer (" ordering" , torch .tensor (ordering , dtype = torch .int64 ))
324
324
self .ordering_reversed : torch .Tensor
325
- self .register_buffer (' ordering_reversed' , torch .scatter (torch .zeros (self .sites , dtype = torch .int64 ), 0 , self .ordering , torch .arange (self .sites , dtype = torch .int64 )))
325
+ self .register_buffer (" ordering_reversed" , torch .scatter (torch .zeros (self .sites , dtype = torch .int64 ), 0 , self .ordering , torch .arange (self .sites , dtype = torch .int64 )))
326
326
327
327
# Dummy Parameter for Device and Dtype Retrieval
328
328
# This parameter is used to infer the device and dtype of the model.
@@ -653,9 +653,9 @@ def __init__( # pylint: disable=too-many-arguments
653
653
if isinstance (ordering , int ) and ordering == - 1 :
654
654
ordering = list (reversed (range (self .sites )))
655
655
self .ordering : torch .Tensor
656
- self .register_buffer (' ordering' , torch .tensor (ordering , dtype = torch .int64 ))
656
+ self .register_buffer (" ordering" , torch .tensor (ordering , dtype = torch .int64 ))
657
657
self .ordering_reversed : torch .Tensor
658
- self .register_buffer (' ordering_reversed' , torch .scatter (torch .zeros (self .sites , dtype = torch .int64 ), 0 , self .ordering , torch .arange (self .sites , dtype = torch .int64 )))
658
+ self .register_buffer (" ordering_reversed" , torch .scatter (torch .zeros (self .sites , dtype = torch .int64 ), 0 , self .ordering , torch .arange (self .sites , dtype = torch .int64 )))
659
659
660
660
# Dummy Parameter for Device and Dtype Retrieval
661
661
# This parameter is used to infer the device and dtype of the model.
0 commit comments