Skip to content

Commit 17c6316

Browse files
gchalump authored and facebook-github-bot committed
fix tbe reporter (pytorch#4882)
Summary: Pull Request resolved: pytorch#4882 - fix mean type - fix tensor type Reviewed By: YanXiong-Meta Differential Revision: D82248514 fbshipit-source-id: 7ff43e5d3ba8ae36cbda12fa464b595deceb7637
1 parent f9ccd01 commit 17c6316

File tree

2 files changed

+8
-3
lines changed

2 files changed

+8
-3
lines changed

fbgemm_gpu/fbgemm_gpu/tbe/bench/tbe_data_config_bench_helper.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,7 @@ def generate_requests(
191191

192192
# Generate indices
193193
all_indices = _generate_indices(tbe_data_config, iters, Bs, L_offsets)
194+
all_indices = all_indices.to(get_device())
194195

195196
# Build TBE requests
196197
if tbe_data_config.variable_B() or tbe_data_config.variable_L():

fbgemm_gpu/fbgemm_gpu/tbe/stats/bench_params_reporter.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -177,15 +177,19 @@ def extract_params(
177177
E = (
178178
Es[0]
179179
if len(set(Es)) == 1
180-
else torch.ceil(torch.mean(torch.tensor(feature_rows)))
180+
else torch.ceil(
181+
torch.mean(torch.tensor(feature_rows, dtype=torch.float))
182+
).item()
181183
)
182184
# Set mixed_dim to be True if there are multiple dims
183185
mixed_dim = len(set(Ds)) > 1
184186
# Set D to be the mean of the dims to avoid biasing
185187
D = (
186188
Ds[0]
187189
if not mixed_dim
188-
else torch.ceil(torch.mean(torch.tensor(feature_dims)))
190+
else torch.ceil(
191+
torch.mean(torch.tensor(feature_dims, dtype=torch.float))
192+
).item()
189193
)
190194

191195
# Compute indices distribution parameters
@@ -198,7 +202,7 @@ def extract_params(
198202

199203
# Compute batch parameters
200204
batch_params = BatchParams(
201-
B=((offsets.numel() - 1) // T),
205+
B=int((offsets.numel() - 1) // T),
202206
sigma_B=(
203207
int(
204208
torch.ceil(

0 commit comments

Comments (0)