6 changes: 6 additions & 0 deletions src/diffusers/models/modeling_utils.py
@@ -803,6 +803,12 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
subfolder=subfolder or "",
)
if hf_quantizer is not None:
is_torchao_quantization_method = quantization_config.quant_method == QuantizationMethod.TORCHAO
@yiyixuxu (Collaborator) commented on Dec 17, 2024:
Can we consolidate this with the bnb check (remove the bnb check and extend this check to any quantization method)?

is_bnb_quantization_method = hf_quantizer.quantization_config.quant_method.value == "bitsandbytes"

This should not be specific to any one quantization method, no? I ran this test: for a non-sharded checkpoint, both work; for a sharded checkpoint, both throw the same error.

from diffusers import FluxPipeline, FluxTransformer2DModel, TorchAoConfig, BitsAndBytesConfig
import torch

sharded_model_id = "black-forest-labs/Flux.1-Dev"
single_model_path = "/raid/yiyi/flux_model_single"
dtype = torch.bfloat16

# create a non-sharded checkpoint
# transformer = FluxTransformer2DModel.from_pretrained(
#     sharded_model_id,
#     subfolder="transformer",
#     torch_dtype=dtype,
# )
# transformer.save_pretrained(single_model_path, max_shard_size="100GB")

torch_ao_quantization_config = TorchAoConfig("int8wo")
bnb_quantization_config = BitsAndBytesConfig(load_in_8bit=True)

print(f" testing non-sharded checkpoint")
transformer = FluxTransformer2DModel.from_pretrained(
    single_model_path,
    quantization_config=torch_ao_quantization_config,
    device_map="auto",
    torch_dtype=dtype,
)

print(f"torchao hf_device_map: {transformer.hf_device_map}")

transformer = FluxTransformer2DModel.from_pretrained(
    single_model_path, 
    quantization_config=bnb_quantization_config,
    device_map="auto",
    torch_dtype=dtype,
)
print(f"bnb hf_device_map: {transformer.hf_device_map}")


print(f" testing sharded checkpoint")
## sharded checkpoint
try:
    transformer = FluxTransformer2DModel.from_pretrained(
        sharded_model_id, 
        subfolder="transformer",
        quantization_config=torch_ao_quantization_config,
        device_map="auto",
        torch_dtype=dtype,
    )
    print(f"torchao: {transformer.hf_device_map}")
except Exception as e:
    print(f"error: {e}")

try:
    transformer = FluxTransformer2DModel.from_pretrained(
        sharded_model_id,
        subfolder="transformer",
        quantization_config=bnb_quantization_config,
        device_map="auto",
        torch_dtype=dtype,
    )
    print(f"bnb hf_device_map: {transformer.hf_device_map}")
except Exception as e:
    print(f"error: {e}")

@a-r-r-o-w (Contributor, Author) commented on Dec 17, 2024:
I don't think non-sharded works for both, no? A non-sharded checkpoint only seems to work with torchao at the moment. These are my results:

| method  | sharded | non-sharded |
| ------- | ------- | ----------- |
| torchao | fails   | works       |
| bnb     | fails   | fails       |

I tried with your code as well and got the following error when using BnB with an unsharded checkpoint on this branch:

NotImplementedError: Currently, `device_map` is automatically inferred for quantized bitsandbytes models. Support for providing `device_map` as an input will be added in the future.

Whatever the automatic inference of `device_map` amounts to, we are still unable to pass `device_map` manually whether the state dict is sharded or unsharded, so I would put it in the same bucket as failing.
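
For reference, a minimal sketch (untested on this branch) of the bitsandbytes call with `device_map` omitted, which is the path the error message points to; it reuses `single_model_path`, `bnb_quantization_config`, and `dtype` from the script above:

```python
# Sketch only: omit `device_map` so bitsandbytes infers placement itself,
# per the NotImplementedError above. Reuses identifiers from the earlier script.
transformer = FluxTransformer2DModel.from_pretrained(
    single_model_path,
    quantization_config=bnb_quantization_config,
    torch_dtype=dtype,
)
```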

Consolidating the checks sounds good. Will update.
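
For illustration only, a minimal sketch of the consolidated check being discussed; the existing diff below still guards on the torchao-specific flag, and the error message wording here is assumed, not final:

```python
# Sketch only (not the final change): inside the existing
# `if hf_quantizer is not None:` block, drop the torchao-specific flag so the
# restriction applies to any quantization method when `device_map` is passed.
if device_map is not None:
    raise NotImplementedError(
        "Loading sharded checkpoints while passing `device_map` is not supported "
        "with quantized models. This will be supported in the near future."
    )
```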

if device_map is not None and is_torchao_quantization_method:
raise NotImplementedError(
"Loading sharded checkpoints, while passing `device_map`, is not supported with `torchao` quantization. This will be supported in the near future."
)

model_file = _merge_sharded_checkpoints(sharded_ckpt_cached_folder, sharded_metadata)
logger.info("Merged sharded checkpoints as `hf_quantizer` is not None.")
is_sharded = False