
Commit e23a66a

Merge pull request #260 from cytomining/partition_profiling
Partition column in the metadata during profiling
2 parents: 16bc87e + 1a361cc
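In short: read_dataset() gains a mode keyword argument (defaulting to 'train'), and the CLI commands now pass it explicitly, so profiling no longer exercises the train-time partition logic. The resulting call sites, as they appear in the diffs below (config stands in for the configuration dictionary that the click context supplies at runtime):

import deepprofiler.dataset.image_dataset

# sample_sc and train keep the previous train-time behavior explicitly:
dset = deepprofiler.dataset.image_dataset.read_dataset(config, mode='train')

# profile skips the train-only metadata split and target registration:
dset = deepprofiler.dataset.image_dataset.read_dataset(config, mode='profile')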

File tree: 2 files changed, +15 −12 lines


deepprofiler/__main__.py

Lines changed: 4 additions & 4 deletions
@@ -146,7 +146,7 @@ def prepare(context):
 def sample_sc(context):
     if context.parent.obj["config"]["prepare"]["compression"]["implement"]:
         context.parent.obj["config"]["paths"]["images"] = context.obj["config"]["paths"]["compressed_images"]
-    dset = deepprofiler.dataset.image_dataset.read_dataset(context.obj["config"])
+    dset = deepprofiler.dataset.image_dataset.read_dataset(context.obj["config"], mode='train')
     deepprofiler.dataset.sampling.sample_dataset(context.obj["config"], dset)
     print("Single-cell sampling complete.")

@@ -159,7 +159,7 @@ def sample_sc(context):
 def train(context, epoch, seed):
     if context.parent.obj["config"]["prepare"]["compression"]["implement"]:
         context.parent.obj["config"]["paths"]["images"] = context.obj["config"]["paths"]["compressed_images"]
-    dset = deepprofiler.dataset.image_dataset.read_dataset(context.obj["config"])
+    dset = deepprofiler.dataset.image_dataset.read_dataset(context.obj["config"], mode='train')
     deepprofiler.learning.training.learn_model(context.obj["config"], dset, epoch, seed)

@@ -177,8 +177,8 @@ def profile(context, part):
     if part >= 0:
         partfile = "index-{0:03d}.csv".format(part)
         config["paths"]["index"] = context.obj["config"]["paths"]["index"].replace("index.csv", partfile)
-    metadata = deepprofiler.dataset.image_dataset.read_dataset(context.obj["config"])
-    deepprofiler.learning.profiling.profile(context.obj["config"], metadata)
+    dset = deepprofiler.dataset.image_dataset.read_dataset(context.obj["config"], mode='profile')
+    deepprofiler.learning.profiling.profile(context.obj["config"], dset)

 # Auxiliary tool: Split index in multiple parts
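A side note on the part option handled in the last hunk: when part >= 0, the profile command retargets the index path at one shard of a pre-split index before building the dataset. A minimal, runnable sketch of that naming scheme (the example path is hypothetical; only the format string and the replace() call come from the diff):

# Hypothetical index location; only the naming scheme is taken from the diff.
index_path = "/project/inputs/metadata/index.csv"
part = 7
partfile = "index-{0:03d}.csv".format(part)             # zero-padded: "index-007.csv"
part_path = index_path.replace("index.csv", partfile)
print(part_path)                                        # /project/inputs/metadata/index-007.csv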

deepprofiler/dataset/image_dataset.py

Lines changed: 11 additions & 8 deletions
@@ -195,7 +195,7 @@ def number_of_records(self, dataset):
     def add_target(self, new_target):
         self.targets.append(new_target)

-def read_dataset(config):
+def read_dataset(config, mode = 'train'):
     # Read metadata and split dataset in training and validation
     metadata = deepprofiler.dataset.metadata.Metadata(config["paths"]["index"], dtype=None)
     if config["prepare"]["compression"]["implement"]:

@@ -211,10 +211,12 @@ def read_dataset(config):
     print(metadata.data.info())

     # Split training data
-    split_field = config["train"]["partition"]["split_field"]
-    trainingFilter = lambda df: df[split_field].isin(config["train"]["partition"]["training_values"])
-    validationFilter = lambda df: df[split_field].isin(config["train"]["partition"]["validation_values"])
-    metadata.splitMetadata(trainingFilter, validationFilter)
+    if mode == 'train':
+        split_field = config["train"]["partition"]["split_field"]
+        trainingFilter = lambda df: df[split_field].isin(config["train"]["partition"]["training_values"])
+        validationFilter = lambda df: df[split_field].isin(config["train"]["partition"]["validation_values"])
+        metadata.splitMetadata(trainingFilter, validationFilter)
+

     # Create a dataset
     keyGen = lambda r: "{}/{}-{}".format(r["Metadata_Plate"], r["Metadata_Well"], r["Metadata_Site"])

@@ -228,9 +230,10 @@ def read_dataset(config):
     )

     # Add training targets
-    for t in config["train"]["partition"]["targets"]:
-        new_target = deepprofiler.dataset.target.MetadataColumnTarget(t, metadata.data[t].unique())
-        dset.add_target(new_target)
+    if mode == 'train':
+        for t in config["train"]["partition"]["targets"]:
+            new_target = deepprofiler.dataset.target.MetadataColumnTarget(t, metadata.data[t].unique())
+            dset.add_target(new_target)

     # Activate outlines for masking if needed
     if config["dataset"]["locations"]["mask_objects"]:
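Taken together: in 'train' mode read_dataset splits the metadata and registers MetadataColumnTarget labels; in 'profile' mode both steps are skipped and the full index is used as-is. A self-contained toy that mirrors only this control flow (all names and data below are hypothetical, not DeepProfiler's implementation):

import pandas as pd

def read_dataset_toy(index_df, config, mode='train'):
    # Mirrors the gating added in this commit: the metadata split and the
    # label targets exist only in train mode; profile mode keeps the full index.
    dataset = {"metadata": index_df, "targets": []}
    if mode == 'train':
        part = config["train"]["partition"]
        split = index_df[part["split_field"]]
        dataset["train"] = index_df[split.isin(part["training_values"])]
        dataset["val"] = index_df[split.isin(part["validation_values"])]
        for t in part["targets"]:
            dataset["targets"].append((t, sorted(index_df[t].unique())))
    return dataset

config = {"train": {"partition": {"split_field": "Split",
                                  "training_values": ["Training"],
                                  "validation_values": ["Validation"],
                                  "targets": ["Compound"]}}}
index = pd.DataFrame({"Split": ["Training", "Validation", "Training"],
                      "Compound": ["DMSO", "taxol", "taxol"]})

print(read_dataset_toy(index, config, mode='train')["targets"])    # [('Compound', ['DMSO', 'taxol'])]
print(read_dataset_toy(index, config, mode='profile')["targets"])  # []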
