Skip to content

Commit ab7f381

Browse files
committed
fixes
1 parent 24a46cc commit ab7f381

File tree

2 files changed

+7
-4
lines changed

2 files changed

+7
-4
lines changed

benchmarks/benchmarking_flux.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,20 @@
11
import torch
2+
from benchmarking_utils import BenchmarkMixin
23

34
from diffusers import FluxTransformer2DModel
45
from diffusers.utils.testing_utils import torch_device
56

6-
from .benchmarking_utils import BenchmarkMixin
7-
87

98
class BenchmarkFlux(BenchmarkMixin):
109
model_class = FluxTransformer2DModel
1110
compile_kwargs = {"fullgraph": True, "mode": "max-autotune"}
1211

1312
def get_model_init_dict(self):
14-
return {"ckpt_id": "black-forest-labs/FLUX.1-dev", "subfolder": "transformer", "torch_dtype": torch.bfloat16}
13+
return {
14+
"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-dev",
15+
"subfolder": "transformer",
16+
"torch_dtype": torch.bfloat16,
17+
}
1518

1619
def initialize_model(self):
1720
model = self.model_class.from_pretrained(**self.get_model_init_dict())

benchmarks/benchmarking_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import gc
22

33
import torch
4-
from torch.utils.benchmark import benchmark
4+
import torch.utils.benchmark as benchmark
55

66
from diffusers.models.modeling_utils import ModelMixin
77
from diffusers.utils.testing_utils import require_torch_gpu

0 commit comments

Comments
 (0)