@@ -98,6 +98,12 @@ def locations(out_dir, metadata, config, make_struct):
9898 })
9999 locs .to_csv (path , index = False )
100100
101+
102+ @pytest .fixture (scope = "function" )
103+ def gpu ():
104+ return '0'
105+
106+
101107@pytest .fixture (scope = "function" )
102108def checkpoint (config , dataset ):
103109 crop_generator = importlib .import_module (
@@ -107,20 +113,20 @@ def checkpoint(config, dataset):
107113 "plugins.crop_generators.{}" .format (config ["train" ]["model" ]["crop_generator" ])) \
108114 .SingleImageGeneratorClass
109115 dpmodel = importlib .import_module ("plugins.models.{}" .format (config ["train" ]["model" ]["name" ])) \
110- .ModelClass (config , dataset , crop_generator , profile_crop_generator )
116+ .ModelClass (config , dataset , gpu , crop_generator , profile_crop_generator )
111117 dpmodel .feature_model .compile (dpmodel .optimizer , dpmodel .loss )
112118 filename = os .path .join (config ["paths" ]["checkpoints" ], config ["profile" ]["checkpoint" ])
113119 dpmodel .feature_model .save_weights (filename )
114120 return filename
115121
116122
117123@pytest .fixture (scope = "function" )
118- def profile (config , dataset ):
119- return deepprofiler .learning .profiling .Profile (config , dataset )
124+ def profile (config , dataset , gpu ):
125+ return deepprofiler .learning .profiling .Profile (config , dataset , gpu )
120126
121127
122- def test_init (config , dataset ):
123- prof = deepprofiler .learning .profiling .Profile (config , dataset )
128+ def test_init (config , dataset , gpu ):
129+ prof = deepprofiler .learning .profiling .Profile (config , dataset , gpu )
124130 test_num_channels = len (config ["dataset" ]["images" ]["channels" ])
125131 assert prof .config == config
126132 assert prof .dset == dataset
@@ -152,8 +158,8 @@ def test_extract_features(profile, metadata, locations, checkpoint):
152158 assert os .path .isfile (output_file )
153159
154160
155- def test_profile (config , dataset , data , locations , checkpoint ):
156- deepprofiler .learning .profiling .profile (config , dataset )
161+ def test_profile (config , dataset , data , locations , checkpoint , gpu ):
162+ deepprofiler .learning .profiling .profile (config , dataset , gpu )
157163 for index , row in dataset .meta .data .iterrows ():
158164 output_file = config ["paths" ]["features" ] + "/{}_{}_{}.npz" \
159165 .format (row ["Metadata_Plate" ], row ["Metadata_Well" ], row ["Metadata_Site" ])
0 commit comments