1+ #-------------------------------------------------------------
2+ #
3+ # Licensed to the Apache Software Foundation (ASF) under one
4+ # or more contributor license agreements. See the NOTICE file
5+ # distributed with this work for additional information
6+ # regarding copyright ownership. The ASF licenses this file
7+ # to you under the Apache License, Version 2.0 (the
8+ # "License"); you may not use this file except in compliance
9+ # with the License. You may obtain a copy of the License at
10+ #
11+ # http://www.apache.org/licenses/LICENSE-2.0
12+ #
13+ # Unless required by applicable law or agreed to in writing,
14+ # software distributed under the License is distributed on an
15+ # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
16+ # KIND, either express or implied. See the License for the
17+ # specific language governing permissions and limitations
18+ # under the License.
19+ #
20+ #-------------------------------------------------------------
21+
22+ source("nn/layers/embedding.dml") as embedding
23+ source("src/test/scripts/applications/nn/util.dml") as test_util
24+
25+ embedding_test_forward = function() {
26+ print("Testing Embedding - Forward Test")
27+ n = 4
28+ v = 7
29+ d = 3
30+
31+ embedding_dict = matrix("-0.78327566 -0.87246466 -0.80580276
32+ -0.17845497 2.1740944 -1.2514428
33+ -0.27202556 -1.3681601 -1.5384313
34+ 1.4215976 -0.463162 1.2592019
35+ -1.7417 -0.46109396 -0.06011621
36+ -0.7803316 1.0802858 0.7465289
37+ 0. 0. 0.", rows=v, cols=d)
38+ indices = matrix("1 6 7 6", rows=n, cols=1)
39+
40+ embeddings = embedding::forward(indices, embedding_dict)
41+
42+ expected_embeddings = matrix("-0.78327566 -0.87246466 -0.80580276
43+ -0.7803316 1.0802858 0.7465289
44+ 0. 0. 0.
45+ -0.7803316 1.0802858 0.7465289", rows=n, cols=d)
46+
47+ test_util::check_all_close(embeddings, expected_embeddings, 1e-05)
48+ }
49+
50+ embedding_test_forward_backward_no_pad = function() {
51+ print("Testing Embedding - Forward & Backward Test w/out Padding")
52+ n = 2
53+ v = 4
54+ d = 3
55+
56+ embedding_dict = matrix("-0.15039968 0.56168836 -0.577436
57+ 0.47334725 1.5215642 -0.1924941
58+ 1.600819 -1.1331359 -2.58817
59+ 0.9779929 -0.82212716 -1.5917081", rows=v, cols=d)
60+ indices = matrix("2 3", rows=n, cols=1)
61+
62+ embeddings = embedding::forward(indices, embedding_dict)
63+
64+ expected_embeddings = matrix("0.47334725 1.5215642 -0.1924941
65+ 1.600819 -1.1331359 -2.58817", rows=n, cols=d)
66+
67+ test_util::check_all_close(embeddings, expected_embeddings, 1e-05)
68+
69+ dout = matrix(seq(1, n*d, 1), rows=n, cols=d)
70+ padding_idx = -1
71+
72+ dembedding_dict = embedding::backward(dout, indices, v, padding_idx)
73+ expected_dembedding_dict = matrix("0. 0. 0.
74+ 1. 2. 3.
75+ 4. 5. 6.
76+ 0. 0. 0.", rows=v, cols=d)
77+ test_util::check_all_close(dembedding_dict, expected_dembedding_dict, 1e-05)
78+ }
79+
80+ embedding_test_forward_backward_pad = function() {
81+ print("Testing Embedding - Forward & Backward Test w/ Padding")
82+ n = 5
83+ v = 10
84+ d = 6
85+
86+ embedding_dict = matrix("-1.24377859e+00 -1.10724878e+00 2.35533118e-01 6.65530920e-01
87+ 9.80555452e-03 6.31030917e-01
88+ 8.16493928e-01 -6.21011078e-01 -5.75569510e-01 -3.93419750e-02
89+ -6.20878041e-01 1.37852756e-02
90+ 7.43950903e-01 1.60437262e+00 -2.31788456e-01 1.15943216e-01
91+ -8.83608997e-01 1.11547875e+00
92+ 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00
93+ 0.00000000e+00 0.00000000e+00
94+ 1.70598769e+00 1.82770026e+00 1.30581510e+00 1.05738208e-01
95+ 4.50116873e-01 3.48498315e-01
96+ 1.40551448e+00 3.43091488e-02 1.84714049e-03 -5.52828193e-01
97+ 3.65064174e-01 -9.31223869e-01
98+ 1.33713937e+00 -3.43729639e+00 -1.22915792e+00 -1.12923630e-01
99+ -1.16292477e+00 -2.16708351e-02
100+ 6.63879395e-01 -2.76697308e-01 -9.02738094e-01 -6.85515344e-01
101+ -6.43863618e-01 -2.30419707e+00
102+ 1.44121364e-01 5.20578504e-01 -6.53087497e-01 6.62900746e-01
103+ 3.82369667e-01 -2.25386508e-02
104+ 2.20637798e+00 -6.86733365e-01 -1.27398467e+00 6.28316283e-01
105+ 2.70236313e-01 2.20882833e-01", rows=v, cols=d)
106+ indices = matrix("1 1 1 4 6", rows=n, cols=1)
107+
108+ embeddings = embedding::forward(indices, embedding_dict)
109+
110+ expected_embeddings = matrix("-1.2437786 -1.1072488 0.23553312 0.6655309 0.00980555 0.6310309
111+ -1.2437786 -1.1072488 0.23553312 0.6655309 0.00980555 0.6310309
112+ -1.2437786 -1.1072488 0.23553312 0.6655309 0.00980555 0.6310309
113+ 0. 0. 0. 0. 0. 0.
114+ 1.4055145 0.03430915 0.00184714 -0.5528282 0.36506417 -0.93122387", rows=n, cols=d)
115+
116+ test_util::check_all_close(embeddings, expected_embeddings, 1e-05)
117+
118+ dout = matrix(seq(1, n*d, 1), rows=n, cols=d)
119+ padding_idx = 4
120+
121+ dembedding_dict = embedding::backward(dout, indices, v, padding_idx)
122+ expected_dembedding_dict = matrix("21. 24. 27. 30. 33. 36.
123+ 0. 0. 0. 0. 0. 0.
124+ 0. 0. 0. 0. 0. 0.
125+ 0. 0. 0. 0. 0. 0.
126+ 0. 0. 0. 0. 0. 0.
127+ 25. 26. 27. 28. 29. 30.
128+ 0. 0. 0. 0. 0. 0.
129+ 0. 0. 0. 0. 0. 0.
130+ 0. 0. 0. 0. 0. 0.
131+ 0. 0. 0. 0. 0. 0.", rows=v, cols=d)
132+ test_util::check_all_close(dembedding_dict, expected_dembedding_dict, 1e-05)
133+ }
134+
135+ embedding_test_forward()
136+ embedding_test_forward_backward_no_pad()
137+ embedding_test_forward_backward_pad()
0 commit comments