Skip to content

Commit 17fb8ba

Browse files
committed
log formatting.
1 parent 7b3a1f2 commit 17fb8ba

File tree

1 file changed

+14
-14
lines changed

1 file changed

+14
-14
lines changed

grouped_gemm/ops.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ def forward(ctx,
6767
if input_act.is_cpu:
6868
raise RuntimeError("[Error] The input `input_act` of permute_topK op is on the device: CPU!")
6969
if indices.is_cpu:
70-
print("[Warning] The input `indices` of permute_topK op is on the device: CPU!", file=stderr)
70+
warnings.warn("The input `indices` of permute_topK op is on the device: CPU!")
7171
expert_for_rows = expert_for_rows.cuda()
7272

7373
# Shape check
@@ -77,16 +77,16 @@ def forward(ctx,
7777

7878
# Data type check
7979
if indices.dtype != torch.int32:
80-
print(f"[Warning] The data type of the input `indices` of permute_topK op is {indices.dtype}! "
81-
"The recommended type is torch.int32.", file=stderr)
80+
warnings.warn(f"The data type of the input `indices` of permute_topK op is {indices.dtype}! "
81+
"The recommended type is torch.int32.")
8282
indices = indices.to(torch.int32)
8383

8484
# Contiguous check
8585
if not input_act.is_contiguous():
86-
print("[Warning] The input `input_act` of permute_topK op is discontiguous!", file=stderr)
86+
warnings.warn("The input `input_act` of permute_topK op is discontiguous!")
8787
input_act = input_act.contiguous()
8888
if not indices.is_contiguous():
89-
print("[Warning] The input `indices` of permute_topK op is discontiguous!", file=stderr)
89+
warnings.warn("The input `indices` of permute_topK op is discontiguous!")
9090
indices = indices.contiguous()
9191

9292
num_topK = indices.size(1)
@@ -159,10 +159,10 @@ def forward(ctx,
159159
if input_act.is_cpu:
160160
raise RuntimeError("[Error] The input `input_act` of unpermute_topK op is on the device: CPU!")
161161
if row_id_map.is_cpu:
162-
print("[Warning] The input `row_id_map` of unpermute_topK op is on the device: CPU!", file=stderr)
162+
warnings.warn("The input `row_id_map` of unpermute_topK op is on the device: CPU!")
163163
row_id_map = row_id_map.cuda()
164164
if probs.is_cpu:
165-
print("[Warning] The input `probs` of unpermute_topK op is on the device: CPU!", file=stderr)
165+
warnings.warn("The input `probs` of unpermute_topK op is on the device: CPU!")
166166
probs = probs.cuda()
167167

168168
# Shape check
@@ -175,23 +175,23 @@ def forward(ctx,
175175

176176
# Data type check
177177
if row_id_map.dtype != torch.int32:
178-
print(f"[Warning] The data type of the input `row_id_map` of unpermute_topK op is {row_id_map.dtype}! "
179-
"The recommended type is torch.int32.", file=stderr)
178+
warnings.warn(f"The data type of the input `row_id_map` of unpermute_topK op is {row_id_map.dtype}! "
179+
"The recommended type is torch.int32.")
180180
row_id_map = row_id_map.to(torch.int32)
181181
if probs.dtype != torch.float32:
182-
print(f"[Warning] The data type of the input `probs` of unpermute_topK op is {probs.dtype}! "
183-
"The recommended type is torch.float32.", file=stderr)
182+
warnings.warn(f"The data type of the input `probs` of unpermute_topK op is {probs.dtype}! "
183+
"The recommended type is torch.float32.")
184184
probs = probs.to(torch.float32)
185185

186186
# Contiguous check
187187
if not input_act.is_contiguous():
188-
print("[Warning] The input `input_act` of unpermute_topK op is discontiguous!", file=stderr)
188+
warnings.warn("The input `input_act` of unpermute_topK op is discontiguous!")
189189
input_act = input_act.contiguous()
190190
if not row_id_map.is_contiguous():
191-
print("[Warning] The input `row_id_map` of unpermute_topK op is discontiguous!", file=stderr)
191+
warnings.warn("The input `row_id_map` of unpermute_topK op is discontiguous!")
192192
row_id_map = row_id_map.contiguous()
193193
if not probs.is_contiguous():
194-
print("[Warning] The input `probs` of unpermute_topK op is discontiguous!", file=stderr)
194+
warnings.warn("The input `probs` of unpermute_topK op is discontiguous!")
195195
probs = probs.contiguous()
196196

197197
num_tokens = probs.size(0)

0 commit comments

Comments
 (0)