Skip to content

Commit 7690b1f

Browse files
committed
Fixing weights
1 parent fcaf51b commit 7690b1f

File tree

1 file changed

+12
-9
lines changed

1 file changed

+12
-9
lines changed

bigml/fusion.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,8 @@ def get_models_weight(models_info):
104104
else:
105105
model_ids = models_info
106106
weights = None
107+
if weights is None:
108+
weights = [1] * len(model_ids)
107109
return model_ids, weights
108110
except KeyError:
109111
raise ValueError("Failed to find the models in the fusion info.")
@@ -289,10 +291,9 @@ def predict_probability(self, input_data,
289291
continue
290292
if self.regression:
291293
prediction = prediction[0]
292-
if self.weights is not None:
293-
weights.append(1 if not self.weights else self.weights[
294-
self.model_ids.index(model.resource_id)])
295-
prediction = self.weigh(prediction, model.resource_id)
294+
weights.append(self.weights[self.model_ids.index(
295+
model.resource_id)])
296+
prediction = self.weigh(prediction, model.resource_id)
296297
# we need to check that all classes in the fusion
297298
# are also in the composing model
298299
if not self.regression and \
@@ -312,7 +313,8 @@ def predict_probability(self, input_data,
312313
total_weight = sum(weights)
313314
for index, pred in enumerate(votes.predictions):
314315
prediction += pred # the weight is already considered in pred
315-
prediction /= float(total_weight)
316+
if total_weight > 0:
317+
prediction /= float(total_weight)
316318
if compact:
317319
output = [prediction]
318320
else:
@@ -378,8 +380,8 @@ def predict_confidence(self, input_data,
378380
# are found and Linear Regressions have no confidence
379381
continue
380382
predictions.append(prediction)
381-
weights.append(1 if not self.weights else self.weights[
382-
self.model_ids.index(model.resource_id)])
383+
weights.append(self.weights[self.model_ids.index(
384+
model.resource_id)])
383385
if self.regression:
384386
prediction = prediction["prediction"]
385387
if self.regression:
@@ -389,8 +391,9 @@ def predict_confidence(self, input_data,
389391
for index, pred in enumerate(predictions):
390392
prediction += pred.get("prediction") * weights[index]
391393
confidence += pred.get("confidence")
392-
prediction /= float(total_weight)
393-
confidence /= float(len(predictions))
394+
if total_weight > 0:
395+
prediction /= float(total_weight)
396+
confidence /= float(len(predictions))
394397
if compact:
395398
output = [prediction, confidence]
396399
else:

0 commit comments

Comments
 (0)