Skip to content

Commit 377ef36

Browse files
added check to skip regression if zero instances of the class are found
1 parent aba01e0 commit 377ef36

File tree

1 file changed

+3
-11
lines changed

1 file changed

+3
-11
lines changed

AFL/double_agent/TreePipeline.py

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -62,19 +62,14 @@ def __init__(self, input_variable, output_variable, key_variable, morphology, mo
6262
def calculate(self, dataset):
6363
data = self._get_variable(dataset)
6464
key = dataset[self.key_variable].data
65-
print(np.unique(key))
66-
print(self.morphology)
6765
inds = np.where(np.equal(key, self.morphology))[0]
68-
predictions = self.regression.predict(data[inds])
6966
if self.output_variable in dataset.data_vars:
7067
output = dataset[self.output_variable].data
7168
else:
7269
output = np.nan * np.ones(data.shape[0])
73-
print("INDS")
74-
print(inds.shape)
75-
print("PREDS")
76-
print(predictions.shape)
77-
output[inds] = predictions.reshape(-1)
70+
if len(inds) > 0:
71+
predictions = self.regression.predict(data[inds])
72+
output[inds] = predictions.reshape(-1)
7873
dataset[self.output_variable] = ('sample', output)
7974
return(self)
8075

@@ -91,12 +86,9 @@ def calculate(self, dataset):
9186
labs = []
9287
for i in range(data.shape[0]):
9388
d = data.data[i]
94-
print(d)
95-
print(type(d))
9689
comps = self.components[d]
9790
measures = np.array([dataset[c].data[i] for c in comps])
9891
portions = measures/np.sum(measures)
99-
print(np.where(portions > self.threshold)[0])
10092
if any(portions >= self.threshold):
10193
labs += [comps[np.where(portions >= self.threshold)[0][0]]]
10294
else:

0 commit comments

Comments
 (0)