@@ -67,7 +67,7 @@ def forward(ctx,
6767 if input_act .is_cpu :
6868 raise RuntimeError ("[Error] The input `input_act` of permute_topK op is on the device: CPU!" )
6969 if indices .is_cpu :
70- print ( "[Warning] The input `indices` of permute_topK op is on the device: CPU!", file = stderr )
70+ warnings . warn ( " The input `indices` of permute_topK op is on the device: CPU!" )
7171 expert_for_rows = expert_for_rows .cuda ()
7272
7373 # Shape check
@@ -77,16 +77,16 @@ def forward(ctx,
7777
7878 # Data type check
7979 if indices .dtype != torch .int32 :
80- print (f"[Warning] The data type of the input `indices` of permute_topK op is { indices .dtype } ! "
81- "The recommended type is torch.int32." , file = stderr )
80+ warnings . warn (f"The data type of the input `indices` of permute_topK op is { indices .dtype } ! "
81+ "The recommended type is torch.int32." )
8282 indices = indices .to (torch .int32 )
8383
8484 # Contiguous check
8585 if not input_act .is_contiguous ():
86- print ( "[Warning] The input `input_act` of permute_topK op is discontiguous!", file = stderr )
86+ warnings . warn ( " The input `input_act` of permute_topK op is discontiguous!" )
8787 input_act = input_act .contiguous ()
8888 if not indices .is_contiguous ():
89- print ( "[Warning] The input `indices` of permute_topK op is discontiguous!", file = stderr )
89+ warnings . warn ( " The input `indices` of permute_topK op is discontiguous!" )
9090 indices = indices .contiguous ()
9191
9292 num_topK = indices .size (1 )
@@ -159,10 +159,10 @@ def forward(ctx,
159159 if input_act .is_cpu :
160160 raise RuntimeError ("[Error] The input `input_act` of unpermute_topK op is on the device: CPU!" )
161161 if row_id_map .is_cpu :
162- print ( "[Warning] The input `row_id_map` of unpermute_topK op is on the device: CPU!", file = stderr )
162+ warnings . warn ( " The input `row_id_map` of unpermute_topK op is on the device: CPU!" )
163163 row_id_map = row_id_map .cuda ()
164164 if probs .is_cpu :
165- print ( "[Warning] The input `probs` of unpermute_topK op is on the device: CPU!", file = stderr )
165+ warnings . warn ( " The input `probs` of unpermute_topK op is on the device: CPU!" )
166166 probs = probs .cuda ()
167167
168168 # Shape check
@@ -175,23 +175,23 @@ def forward(ctx,
175175
176176 # Data type check
177177 if row_id_map .dtype != torch .int32 :
178- print (f"[Warning] The data type of the input `row_id_map` of unpermute_topK op is { row_id_map .dtype } ! "
179- "The recommended type is torch.int32." , file = stderr )
178+ warnings . warn (f"The data type of the input `row_id_map` of unpermute_topK op is { row_id_map .dtype } ! "
179+ "The recommended type is torch.int32." )
180180 row_id_map = row_id_map .to (torch .int32 )
181181 if probs .dtype != torch .float32 :
182- print (f"[Warning] The data type of the input `probs` of unpermute_topK op is { probs .dtype } ! "
183- "The recommended type is torch.float32." , file = stderr )
182+ warnings . warn (f"The data type of the input `probs` of unpermute_topK op is { probs .dtype } ! "
183+ "The recommended type is torch.float32." )
184184 probs = probs .to (torch .float32 )
185185
186186 # Contiguous check
187187 if not input_act .is_contiguous ():
188- print ( "[Warning] The input `input_act` of unpermute_topK op is discontiguous!", file = stderr )
188+ warnings . warn ( " The input `input_act` of unpermute_topK op is discontiguous!" )
189189 input_act = input_act .contiguous ()
190190 if not row_id_map .is_contiguous ():
191- print ( "[Warning] The input `row_id_map` of unpermute_topK op is discontiguous!", file = stderr )
191+ warnings . warn ( " The input `row_id_map` of unpermute_topK op is discontiguous!" )
192192 row_id_map = row_id_map .contiguous ()
193193 if not probs .is_contiguous ():
194- print ( "[Warning] The input `probs` of unpermute_topK op is discontiguous!", file = stderr )
194+ warnings . warn ( " The input `probs` of unpermute_topK op is discontiguous!" )
195195 probs = probs .contiguous ()
196196
197197 num_tokens = probs .size (0 )
0 commit comments