Skip to content

Commit cd69b01

Browse files
committed
Added one more "ad hoc" for the Linear.
1 parent b7fc5c0 commit cd69b01

File tree

1 file changed

+13
-0
lines changed

1 file changed

+13
-0
lines changed

pytorch2keras/layers.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -593,6 +593,19 @@ def convert_matmul(params, w_name, scope_name, inputs, layers, weights):
593593

594594
keras_weights = [W]
595595

596+
dense = keras.layers.Dense(
597+
output_channels,
598+
weights=keras_weights, use_bias=False, name=tf_name
599+
)
600+
layers[scope_name] = dense(layers[inputs[0]])
601+
elif len(inputs) == 2:
602+
weights_name = '{0}.weight'.format(w_name)
603+
604+
W = weights[weights_name].numpy().transpose()
605+
input_channels, output_channels = W.shape
606+
607+
keras_weights = [W]
608+
596609
dense = keras.layers.Dense(
597610
output_channels,
598611
weights=keras_weights, use_bias=False, name=tf_name

0 commit comments

Comments
 (0)