
Commit 210cb4c

Use GPU for loading safetensors, disable export
1 parent e134b74 commit 210cb4c

File tree: 2 files changed (+5 −3 lines)

  modules/sd_models.py
  modules/ui.py

modules/sd_models.py

Lines changed: 3 additions & 2 deletions
@@ -147,8 +147,9 @@ def torch_load(model_filename, model_info, map_override=None):
     map_override=shared.weight_load_location if not map_override else map_override
     if(checkpoint_types[model_info.exttype] == 'safetensors'):
         # safely load weights
-        # TODO: safetensors supports zero copy fast load to gpu, see issue #684
-        return load_file(model_filename, device=map_override)
+        # TODO: safetensors supports zero copy fast load to gpu, see issue #684.
+        # GPU only for now, see https://github.com/huggingface/safetensors/issues/95
+        return load_file(model_filename, device='cuda')
     else:
         return torch.load(model_filename, map_location=map_override)
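
For context, a minimal standalone sketch of the loading behaviour after this change (the load_weights helper and its arguments are illustrative, not part of the repository): safetensors checkpoints are read directly onto the GPU via safetensors.torch.load_file, while other checkpoint formats still go through torch.load with the requested map location.

# Sketch only; helper name and argument defaults are assumptions for illustration.
import torch
from safetensors.torch import load_file

def load_weights(model_filename, map_override="cpu"):
    if model_filename.endswith(".safetensors"):
        # Load straight onto the GPU; CPU loading is skipped for now
        # (see https://github.com/huggingface/safetensors/issues/95).
        return load_file(model_filename, device="cuda")
    # Non-safetensors checkpoints keep the original torch.load path.
    return torch.load(model_filename, map_location=map_override)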

modules/ui.py

Lines changed: 2 additions & 1 deletion
@@ -1187,7 +1187,8 @@ def create_ui(wrap_gradio_gpu_call):
         interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Multiplier (M) - set to 0 to get model A', value=0.3)
         interp_method = gr.Radio(choices=["Weighted sum", "Add difference"], value="Weighted sum", label="Interpolation Method")
         save_as_half = gr.Checkbox(value=False, label="Save as float16")
-        save_as_safetensors = gr.Checkbox(value=False, label="Save as safetensors format")
+        # invisible until feature can be verified
+        save_as_safetensors = gr.Checkbox(value=False, label="Save as safetensors format", visible=False)
         modelmerger_merge = gr.Button(elem_id="modelmerger_merge", label="Merge", variant='primary')

     with gr.Column(variant='panel'):
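
For context, a minimal standalone Gradio sketch of hiding a control with visible=False, as this commit does for the safetensors export checkbox (the Blocks layout and variable names here are illustrative, not the project's actual UI code):

# Sketch only; component layout is an assumption for illustration.
import gradio as gr

with gr.Blocks() as demo:
    save_as_half = gr.Checkbox(value=False, label="Save as float16")
    # Hidden until the safetensors export path is verified; the component still
    # exists and can be re-shown later (e.g. by returning an update with
    # visible=True from an event handler).
    save_as_safetensors = gr.Checkbox(
        value=False, label="Save as safetensors format", visible=False
    )

demo.launch()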
