Skip to content

Commit 75c1873

Browse files
committed
[NPU] add npu support for waveflow, test=develop
1 parent 0267e5b commit 75c1873

File tree

4 files changed

+21
-8
lines changed

4 files changed

+21
-8
lines changed

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,5 @@ output
44
paddlerec.egg-info/
55
*~
66
*.pyc
7-
*.DS_Store
7+
*.DS_Store
8+
kernel_meta/

models/rank/wide_deep/net.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,10 +35,13 @@ def __init__(self, sparse_feature_number, sparse_feature_dim,
3535
initializer=paddle.nn.initializer.TruncatedNormal(
3636
mean=0.0, std=1.0 / math.sqrt(self.dense_feature_dim))))
3737

38+
use_sparse = True
39+
if paddle.is_compiled_with_npu():
40+
use_sparse = False
3841
self.embedding = paddle.nn.Embedding(
3942
self.sparse_feature_number,
4043
self.sparse_feature_dim,
41-
sparse=True,
44+
sparse=use_sparse,
4245
weight_attr=paddle.ParamAttr(
4346
name="SparseFeatFactors",
4447
initializer=paddle.nn.initializer.Uniform()))

tools/infer.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ def main(args):
6666
# tools.vars
6767
use_gpu = config.get("runner.use_gpu", True)
6868
use_xpu = config.get("runner.use_xpu", False)
69+
use_npu = config.get("runner.use_npu", False)
6970
use_visual = config.get("runner.use_visual", False)
7071
test_data_dir = config.get("runner.test_data_dir", None)
7172
print_interval = config.get("runner.print_interval", None)
@@ -76,14 +77,18 @@ def main(args):
7677

7778
logger.info("**************common.configs**********")
7879
logger.info(
79-
"use_gpu: {}, use_xpu: {}, use_visual: {}, infer_batch_size: {}, test_data_dir: {}, start_epoch: {}, end_epoch: {}, print_interval: {}, model_load_path: {}".
80-
format(use_gpu, use_xpu, use_visual, infer_batch_size, test_data_dir,
81-
start_epoch, end_epoch, print_interval, model_load_path))
80+
"use_gpu: {}, use_xpu: {}, use_npu: {}, use_visual: {}, infer_batch_size: {}, test_data_dir: {}, start_epoch: {}, end_epoch: {}, print_interval: {}, model_load_path: {}".
81+
format(use_gpu, use_xpu, use_npu, use_visual, infer_batch_size,
82+
test_data_dir, start_epoch, end_epoch, print_interval,
83+
model_load_path))
8284
logger.info("**************common.configs**********")
8385

8486
if use_xpu:
8587
xpu_device = 'xpu:{0}'.format(os.getenv('FLAGS_selected_xpus', 0))
8688
place = paddle.set_device(xpu_device)
89+
elif use_npu:
90+
npu_device = 'npu:{0}'.format(os.getenv('FLAGS_selected_npus', 0))
91+
place = paddle.set_device(npu_device)
8792
else:
8893
place = paddle.set_device('gpu' if use_gpu else 'cpu')
8994

tools/trainer.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ def main(args):
6565

6666
# tools.vars
6767
use_gpu = config.get("runner.use_gpu", True)
68+
use_npu = config.get("runner.use_npu", False)
6869
use_xpu = config.get("runner.use_xpu", False)
6970
use_visual = config.get("runner.use_visual", False)
7071
train_data_dir = config.get("runner.train_data_dir", None)
@@ -77,14 +78,17 @@ def main(args):
7778

7879
logger.info("**************common.configs**********")
7980
logger.info(
80-
"use_gpu: {}, use_xpu: {}, use_visual: {}, train_batch_size: {}, train_data_dir: {}, epochs: {}, print_interval: {}, model_save_path: {}".
81-
format(use_gpu, use_xpu, use_visual, train_batch_size, train_data_dir,
82-
epochs, print_interval, model_save_path))
81+
"use_gpu: {}, use_xpu: {}, use_npu: {}, use_visual: {}, train_batch_size: {}, train_data_dir: {}, epochs: {}, print_interval: {}, model_save_path: {}".
82+
format(use_gpu, use_xpu, use_npu, use_visual, train_batch_size,
83+
train_data_dir, epochs, print_interval, model_save_path))
8384
logger.info("**************common.configs**********")
8485

8586
if use_xpu:
8687
xpu_device = 'xpu:{0}'.format(os.getenv('FLAGS_selected_xpus', 0))
8788
place = paddle.set_device(xpu_device)
89+
elif use_npu:
90+
npu_device = 'npu:{0}'.format(os.getenv('FLAGS_selected_npus', 0))
91+
place = paddle.set_device(npu_device)
8892
else:
8993
place = paddle.set_device('gpu' if use_gpu else 'cpu')
9094

0 commit comments

Comments
 (0)