Skip to content

Commit a89588f

Browse files
UbuntuUbuntu
authored andcommitted
map config.architecture string to RNN Cell
1 parent 606e0fa commit a89588f

File tree

3 files changed

+14
-11
lines changed

3 files changed

+14
-11
lines changed

examples/pytorch/FastCells/train_classifier.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -414,10 +414,10 @@ def create_model(model_config, input_size, num_keywords):
414414
wSparsity_list = [model_config.wSparsity, model_config.wSparsity, model_config.wSparsity]
415415
uSparsity_list = [model_config.uSparsity, model_config.uSparsity, model_config.uSparsity]
416416
print(model_config.gate_nonlinearity, model_config.update_nonlinearity)
417-
return ModelClass(input_size, model_config.num_layers, hidden_units_list,
418-
wRank_list, uRank_list, wSparsity_list, uSparsity_list,
419-
model_config.gate_nonlinearity, model_config.update_nonlinearity,
420-
num_keywords)
417+
return ModelClass(model_config.architecture, input_size, model_config.num_layers,
418+
hidden_units_list, wRank_list, uRank_list, wSparsity_list,
419+
uSparsity_list, model_config.gate_nonlinearity,
420+
model_config.update_nonlinearity, num_keywords)
421421

422422
def save_json(obj, filename):
423423
with open(filename, "w") as f:

pytorch/edgeml_pytorch/graph/rnn.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1100,7 +1100,9 @@ def forward(self, input, hiddenState=None, cellState=None):
11001100
class FastGRNNCUDA(nn.Module):
11011101
"""Unrolled implementation of the FastGRNNCUDACell"""
11021102
def __init__(self, input_size, hidden_size, gate_nonlinearity="sigmoid",
1103-
update_nonlinearity="tanh", wRank=None, uRank=None, zetaInit=1.0, nuInit=-4.0, name="FastGRNNCUDACell"):
1103+
update_nonlinearity="tanh", wRank=None, uRank=None,
1104+
wSparsity=1.0, uSparsity=1.0, zetaInit=1.0, nuInit=-4.0,
1105+
name="FastGRNNCUDACell"):
11041106
super(FastGRNNCUDA, self).__init__()
11051107
if utils.findCUDA() is None:
11061108
raise Exception('FastGRNNCUDA is supported only on GPU devices.')
@@ -1309,4 +1311,4 @@ def forward(ctx, input, bias_gate, bias_update, zeta, nu, old_h, w, u, w1, w2, u
13091311
def backward(ctx, grad_h):
13101312
outputs = fastgrnn_cuda.backward_unroll(
13111313
grad_h.contiguous(), *ctx.saved_variables, ctx.gate_non_linearity)
1312-
return tuple(outputs + [None])
1314+
return tuple(outputs + [None])

pytorch/edgeml_pytorch/trainer/fastmodel.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,10 @@ class RNNClassifierModel(inheritance_class):
1414
RNN-based classifier
1515
"""
1616

17-
def __init__(self, input_dim, num_layers, hidden_units_list,
17+
def __init__(self, rnn_name, input_dim, num_layers, hidden_units_list,
1818
wRank_list, uRank_list, wSparsity_list, uSparsity_list,
19-
gate_nonlinearity, update_nonlinearity,
20-
num_classes=None, linear=True, batch_first=False, apply_softmax=True):
19+
gate_nonlinearity, update_nonlinearity, num_classes=None,
20+
linear=True, batch_first=False, apply_softmax=True):
2121
"""
2222
Initialize the KeywordSpotter with the following parameters:
2323
input_dim - the size of the input audio frame in # samples.
@@ -45,8 +45,9 @@ def __init__(self, input_dim, num_layers, hidden_units_list,
4545

4646
super(RNNClassifierModel, self).__init__()
4747

48+
RNN = getattr(getattr(getattr(__import__('edgeml_pytorch'), 'graph'), 'rnn'), rnn_name)
4849
self.rnn_list = nn.ModuleList([
49-
FastGRNN(self.input_dim if l==0 else self.hidden_units_list[l-1],
50+
RNN(self.input_dim if l==0 else self.hidden_units_list[l-1],
5051
self.hidden_units_list[l],
5152
gate_nonlinearity=self.gate_nonlinearity,
5253
update_nonlinearity=self.update_nonlinearity,
@@ -144,4 +145,4 @@ def forward(self, input):
144145
if self.apply_softmax:
145146
model_output = F.log_softmax(model_output, dim=1)
146147
return model_output
147-
return RNNClassifierModel
148+
return RNNClassifierModel

0 commit comments

Comments
 (0)