Skip to content

Commit 0aabcd3

Browse files
authored
Merge pull request #717 from windstamp/npu_dev
[NPU] fix for ddn, ffm, fm
2 parents 7139b6c + dee2b31 commit 0aabcd3

File tree

3 files changed

+17
-5
lines changed

3 files changed

+17
-5
lines changed

models/rank/dnn/net.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,10 +34,14 @@ def __init__(self,
3434
self.num_field = num_field
3535
self.layer_sizes = layer_sizes
3636

37+
use_sparse = True
38+
if paddle.is_compiled_with_npu():
39+
use_sparse = False
40+
3741
self.embedding = paddle.nn.Embedding(
3842
self.sparse_feature_number,
3943
self.sparse_feature_dim,
40-
sparse=True,
44+
sparse=use_sparse,
4145
weight_attr=paddle.ParamAttr(
4246
name="SparseFeatFactors",
4347
initializer=paddle.nn.initializer.Uniform()))

models/rank/ffm/net.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,11 +57,15 @@ def __init__(self, sparse_feature_number, sparse_feature_dim,
5757
self.sparse_num_field = sparse_num_field
5858
self.init_value_ = 0.1
5959

60+
use_sparse = True
61+
if paddle.is_compiled_with_npu():
62+
use_sparse = False
63+
6064
# sparse part coding
6165
self.embedding_one = paddle.nn.Embedding(
6266
sparse_feature_number,
6367
1,
64-
sparse=True,
68+
sparse=use_sparse,
6569
weight_attr=paddle.ParamAttr(
6670
initializer=paddle.nn.initializer.TruncatedNormal(
6771
mean=0.0,
@@ -71,7 +75,7 @@ def __init__(self, sparse_feature_number, sparse_feature_dim,
7175
self.embedding = paddle.nn.Embedding(
7276
self.sparse_feature_number,
7377
self.sparse_feature_dim * self.sparse_num_field,
74-
sparse=True,
78+
sparse=use_sparse,
7579
weight_attr=paddle.ParamAttr(
7680
initializer=paddle.nn.initializer.TruncatedNormal(
7781
mean=0.0,

models/rank/fm/net.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,11 +56,15 @@ def __init__(self, sparse_feature_number, sparse_feature_dim,
5656
self.sparse_num_field = sparse_num_field
5757
self.init_value_ = 0.1
5858

59+
use_sparse = True
60+
if paddle.is_compiled_with_npu():
61+
use_sparse = False
62+
5963
# sparse part coding
6064
self.embedding_one = paddle.nn.Embedding(
6165
sparse_feature_number,
6266
1,
63-
sparse=True,
67+
sparse=use_sparse,
6468
weight_attr=paddle.ParamAttr(
6569
initializer=paddle.nn.initializer.TruncatedNormal(
6670
mean=0.0,
@@ -70,7 +74,7 @@ def __init__(self, sparse_feature_number, sparse_feature_dim,
7074
self.embedding = paddle.nn.Embedding(
7175
self.sparse_feature_number,
7276
self.sparse_feature_dim,
73-
sparse=True,
77+
sparse=use_sparse,
7478
weight_attr=paddle.ParamAttr(
7579
initializer=paddle.nn.initializer.TruncatedNormal(
7680
mean=0.0,

0 commit comments

Comments
 (0)