2 files changed: +8 −3 lines changed
@@ -259,7 +259,8 @@ def get_cached_module_file(
             local_files_only=local_files_only,
             use_auth_token=False,
         )
-        submodule = "local"
+        submodule = "git"
+        module_file = pretrained_model_name_or_path + ".py"
     except EnvironmentError:
         logger.error(f"Could not locate the {module_file} inside {pretrained_model_name_or_path}.")
         raise
@@ -288,7 +289,7 @@ def get_cached_module_file(
     full_submodule = DIFFUSERS_DYNAMIC_MODULE_NAME + os.path.sep + submodule
     create_dynamic_module(full_submodule)
     submodule_path = Path(HF_MODULES_CACHE) / full_submodule
-    if submodule == "local":
+    if submodule == "local" or submodule == "git":
         # We always copy local files (we could hash the file to see if there was a change, and give them the name of
         # that hash, to only copy when there is a modification but it seems overkill for now).
         # The only reason we do the copy is to avoid putting too many folders in sys.path.
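Taken together, these hunks make Hub-hosted community pipelines (`submodule = "git"`) take the same copy path as local files: the downloaded `<name>.py` module is copied into the dynamic modules cache so that only the cache root has to sit on `sys.path`. A minimal sketch of that copy step under those assumptions; `_copy_into_cache` and `resolved_module_file` are illustrative names, not diffusers' API:

```python
import shutil
from pathlib import Path

def _copy_into_cache(resolved_module_file: str, submodule_path: Path, module_file: str) -> None:
    # Illustrative helper, not part of diffusers: both "local" and "git"
    # modules are unconditionally copied into the HF modules cache, which
    # avoids adding one sys.path entry per custom pipeline.
    submodule_path.mkdir(parents=True, exist_ok=True)
    shutil.copy(resolved_module_file, submodule_path / module_file)
```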
@@ -112,18 +112,22 @@ def test_local_custom_pipeline(self):
         assert output_str == "This is a local test"

     @slow
+    @unittest.skipIf(torch_device == "cpu", "Stable diffusion is supposed to run on GPU")
     def test_load_pipeline_from_git(self):
         clip_model_id = "laion/CLIP-ViT-B-32-laion2B-s34B-b79K"

         feature_extractor = CLIPFeatureExtractor.from_pretrained(clip_model_id)
-        clip_model = CLIPModel.from_pretrained(clip_model_id)
+        clip_model = CLIPModel.from_pretrained(clip_model_id, torch_dtype=torch.float16)

         pipeline = DiffusionPipeline.from_pretrained(
             "CompVis/stable-diffusion-v1-4",
             custom_pipeline="clip_guided_stable_diffusion",
             clip_model=clip_model,
             feature_extractor=feature_extractor,
+            torch_dtype=torch.float16,
+            revision="fp16",
         )
+        pipeline.enable_attention_slicing()
         pipeline = pipeline.to(torch_device)
         # NOTE that `"CLIPGuidedStableDiffusion"` is not a class defined in the PyPI package of the library, but lives solely in the community examples folder on GitHub under:
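The updated test doubles as a usage example: it loads the community `clip_guided_stable_diffusion` pipeline in half precision. A hedged sketch of the same call outside the test harness, assuming a CUDA device is available (the new decorator skips the test on CPU):

```python
# Usage sketch mirroring the test above; requires a CUDA GPU and the
# diffusers/transformers versions current at the time of this change.
import torch
from diffusers import DiffusionPipeline
from transformers import CLIPFeatureExtractor, CLIPModel

clip_model_id = "laion/CLIP-ViT-B-32-laion2B-s34B-b79K"
feature_extractor = CLIPFeatureExtractor.from_pretrained(clip_model_id)
clip_model = CLIPModel.from_pretrained(clip_model_id, torch_dtype=torch.float16)

pipeline = DiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    custom_pipeline="clip_guided_stable_diffusion",  # resolved from the community examples folder
    clip_model=clip_model,
    feature_extractor=feature_extractor,
    torch_dtype=torch.float16,
    revision="fp16",
)
pipeline.enable_attention_slicing()
pipeline = pipeline.to("cuda")
```

Here `revision="fp16"` pulls the half-precision weights branch of the checkpoint, and `enable_attention_slicing()` trades some speed for a lower peak memory footprint.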