Skip to content

Commit 7b997cd

Browse files
committed
merge
2 parents b673a26 + 35ff995 commit 7b997cd

File tree

2 files changed

+13
-21
lines changed

2 files changed

+13
-21
lines changed

examples/keras_integration.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import numpy as np
88
from keras.datasets import mnist
99
from keras.models import Sequential
10-
from keras.layers import Conv2D, MaxPooling2D, Dropout, Flatten, Dense
10+
from keras.layers import Dense, Dropout, Flatten, Conv2D, MaxPooling2D
1111
from keras.wrappers.scikit_learn import KerasClassifier
1212
from modAL.models import ActiveLearner
1313

@@ -18,8 +18,10 @@ def create_keras_model():
1818
This function compiles and returns a Keras model.
1919
Should be passed to KerasClassifier in the Keras scikit-learn API.
2020
"""
21+
2122
model = Sequential()
2223
model.add(Conv2D(32, kernel_size=(3, 3), activation='relu', input_shape=(28, 28, 1)))
24+
model.add(Conv2D(64, (3, 3), activation='relu'))
2325
model.add(MaxPooling2D(pool_size=(2, 2)))
2426
model.add(Dropout(0.25))
2527
model.add(Flatten())
@@ -68,20 +70,21 @@ def create_keras_model():
6870
learner = ActiveLearner(
6971
estimator=classifier,
7072
X_training=X_initial, y_training=y_initial,
71-
verbose=0
73+
verbose=1
7274
)
7375

7476
# the active learning loop
7577
n_queries = 10
7678
for idx in range(n_queries):
7779
query_idx, query_instance = learner.query(X_pool, n_instances=200, verbose=0)
80+
print(query_idx)
7881
learner.teach(
7982
X=X_pool[query_idx], y=y_pool[query_idx],
80-
verbose=0
83+
verbose=1
8184
)
8285
# remove queried instance from pool
8386
X_pool = np.delete(X_pool, query_idx, axis=0)
8487
y_pool = np.delete(y_pool, query_idx, axis=0)
8588

8689
# the final accuracy score
87-
print(learner.score(X_test, y_test, verbose=0))
90+
print(learner.score(X_test, y_test, verbose=1))

modAL/utils/combination.py

Lines changed: 6 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,6 @@
11
import numpy as np
22

33

4-
def check_init(obj, typ):
5-
if obj is None:
6-
return typ()
7-
else:
8-
assert isinstance(obj, typ), 'obj must be of type %s' % typ.__name__
9-
return obj
10-
11-
124
def make_linear_combination(*functions, weights=None):
135
"""
146
Takes the given functions and makes a function which returns the linear combination
@@ -38,8 +30,8 @@ def make_linear_combination(*functions, weights=None):
3830
'same as the number of given functions'
3931

4032
def linear_combination(*args, **kwargs):
41-
return np.sum([weights[i]*functions[i](*args, **kwargs)
42-
for i in range(len(weights))], axis=0)
33+
return np.sum((weights[i]*functions[i](*args, **kwargs)
34+
for i in range(len(weights))), axis=0)
4335

4436
return linear_combination
4537

@@ -79,15 +71,12 @@ def product_function(*args, **kwargs):
7971
return product_function
8072

8173

82-
def make_query_strategy(utility_measure, selector, utility_kwargs, selector_kwargs):
83-
84-
utility_kwargs, selector_kwargs = check_init(utility_kwargs, dict), check_init(selector_kwargs, dict)
85-
74+
def make_query_strategy(utility_measure, selector):
8675
# TODO: check for the signatures of utility_measure and selector
8776

88-
def query_strategy(classifier, X, utility_kwargs, selector_kwargs):
89-
utility = utility_measure(classifier, X, **utility_kwargs)
90-
query_idx, query_instance = selector(utility, **selector_kwargs)
77+
def query_strategy(classifier, X):
78+
utility = utility_measure(classifier, X)
79+
query_idx, query_instance = selector(utility)
9180
return query_idx, query_instance
9281

9382
return query_strategy

0 commit comments

Comments
 (0)