Skip to content

Commit cc0a38a

Browse files
Commit message: "fixes"
committed — 1 parent: ab7f381 · this commit: cc0a38a

File tree

2 files changed

+7
-11
lines changed

2 files changed

+7
-11
lines changed

benchmarks/benchmarking_flux.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
class BenchmarkFlux(BenchmarkMixin):
99
model_class = FluxTransformer2DModel
10-
compile_kwargs = {"fullgraph": True, "mode": "max-autotune"}
10+
compile_kwargs = {"fullgraph": True}
1111

1212
def get_model_init_dict(self):
1313
return {
@@ -29,8 +29,8 @@ def get_input_dict(self):
2929
pooled_prompt_embeds = torch.randn(1, 768, device=torch_device, dtype=torch.bfloat16)
3030
image_ids = torch.ones(512, 3, device=torch_device, dtype=torch.bfloat16)
3131
text_ids = torch.ones(4096, 3, device=torch_device, dtype=torch.bfloat16)
32-
timestep = torch.tensor([1.0], device=torch_device)
33-
guidance = torch.tensor([1.0], device=torch_device)
32+
timestep = torch.tensor([1.0], device=torch_device, dtype=torch.bfloat16)
33+
guidance = torch.tensor([1.0], device=torch_device, dtype=torch.bfloat16)
3434

3535
return {
3636
"hidden_states": hidden_states,

benchmarks/benchmarking_utils.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ def benchmark_fn(f, *args, **kwargs):
1111
t0 = benchmark.Timer(
1212
stmt="f(*args, **kwargs)",
1313
globals={"args": args, "kwargs": kwargs, "f": f},
14-
num_threads=torch.get_num_threads(),
14+
num_threads=1,
1515
)
1616
return f"{(t0.blocked_autorange().mean):.3f}"
1717

@@ -53,10 +53,6 @@ def run_benchmark(self):
5353
model = self.initialize_model() # Takes care of device placement.
5454
input_dict = self.get_input_dict() # Takes care of device placement.
5555

56-
# warmup
57-
for _ in range(5):
58-
_ = model(**input_dict)
59-
6056
time = benchmark_fn(lambda model, input_dict: model(**input_dict), model, input_dict)
6157
memory = torch.cuda.max_memory_allocated() / (1024**3)
6258
memory = float(f"{memory:.2f}")
@@ -69,9 +65,9 @@ def run_benchmark(self):
6965
compile_stats = None
7066
if self.compile_kwargs is not None:
7167
model = self.initialize_model()
72-
with torch._inductor.utils.fresh_inductor_cache():
73-
model.compile(**self.compile_kwargs)
74-
time = benchmark_fn(lambda model, input_dict: model(**input_dict), model, input_dict)
68+
input_dict = self.get_input_dict()
69+
model.compile(**self.compile_kwargs)
70+
time = benchmark_fn(lambda model, input_dict: model(**input_dict), model, input_dict)
7571
memory = torch.cuda.max_memory_allocated() / (1024**3)
7672
memory = float(f"{memory:.2f}")
7773
compile_stats = {"time": time, "memory": memory}

0 commit comments

Comments (0)