@@ -132,69 +132,175 @@ class bfloat16 {
132132#endif
133133 }
134134
135- // Increment and decrement operators overloading
135+ bfloat16 &operator +=(const bfloat16 &rhs) {
136+ value = from_float (to_float (value) + to_float (rhs.value ));
137+ return *this ;
138+ }
139+
140+ bfloat16 &operator -=(const bfloat16 &rhs) {
141+ value = from_float (to_float (value) - to_float (rhs.value ));
142+ return *this ;
143+ }
144+
145+ bfloat16 &operator *=(const bfloat16 &rhs) {
146+ value = from_float (to_float (value) * to_float (rhs.value ));
147+ return *this ;
148+ }
149+
150+ bfloat16 &operator /=(const bfloat16 &rhs) {
151+ value = from_float (to_float (value) / to_float (rhs.value ));
152+ return *this ;
153+ }
154+
155+ // Operator ++, --
156+ bfloat16 &operator ++() {
157+ float f = to_float (value);
158+ value = from_float (++f);
159+ return *this ;
160+ }
161+
162+ bfloat16 operator ++(int ) {
163+ bfloat16 ret (*this );
164+ operator ++();
165+ return ret;
166+ }
167+
168+ bfloat16 &operator --() {
169+ float f = to_float (value);
170+ value = from_float (--f);
171+ return *this ;
172+ }
173+
174+ bfloat16 operator --(int ) {
175+ bfloat16 ret (*this );
176+ operator --();
177+ return ret;
178+ }
179+
180+ // Operator +, -, *, /
136181#define OP (op ) \
137- friend bfloat16 &operator op (bfloat16 &lhs) { \
138- float f = to_float (lhs.value ); \
139- lhs.value = from_float (op f); \
140- return lhs; \
141- } \
142- friend bfloat16 operator op (bfloat16 &lhs, int ) { \
143- bfloat16 old = lhs; \
144- operator op (lhs); \
145- return old; \
146- }
147- OP (++)
148- OP (--)
182+ friend bfloat16 operator op (const bfloat16 lhs, const bfloat16 rhs) { \
183+ return to_float (lhs.value ) op to_float (rhs.value ); \
184+ } \
185+ friend double operator op (const bfloat16 lhs, const double rhs) { \
186+ return to_float (lhs.value ) op rhs; \
187+ } \
188+ friend double operator op (const double lhs, const bfloat16 rhs) { \
189+ return lhs op to_float (rhs.value ); \
190+ } \
191+ friend float operator op (const bfloat16 lhs, const float rhs) { \
192+ return to_float (lhs.value ) op rhs; \
193+ } \
194+ friend float operator op (const float lhs, const bfloat16 rhs) { \
195+ return lhs op to_float (rhs.value ); \
196+ } \
197+ friend bfloat16 operator op (const bfloat16 lhs, const int rhs) { \
198+ return to_float (lhs.value ) op rhs; \
199+ } \
200+ friend bfloat16 operator op (const int lhs, const bfloat16 rhs) { \
201+ return lhs op to_float (rhs.value ); \
202+ } \
203+ friend bfloat16 operator op (const bfloat16 lhs, const long rhs) { \
204+ return to_float (lhs.value ) op rhs; \
205+ } \
206+ friend bfloat16 operator op (const long lhs, const bfloat16 rhs) { \
207+ return lhs op to_float (rhs.value ); \
208+ } \
209+ friend bfloat16 operator op (const bfloat16 lhs, const long long rhs) { \
210+ return to_float (lhs.value ) op rhs; \
211+ } \
212+ friend bfloat16 operator op (const long long lhs, const bfloat16 rhs) { \
213+ return lhs op to_float (rhs.value ); \
214+ } \
215+ friend bfloat16 operator op (const bfloat16 &lhs, const unsigned int &rhs) { \
216+ return to_float (lhs.value ) op rhs; \
217+ } \
218+ friend bfloat16 operator op (const unsigned int &lhs, const bfloat16 &rhs) { \
219+ return lhs op to_float (rhs.value ); \
220+ } \
221+ friend bfloat16 operator op (const bfloat16 &lhs, const unsigned long &rhs) { \
222+ return to_float (lhs.value ) op rhs; \
223+ } \
224+ friend bfloat16 operator op (const unsigned long &lhs, const bfloat16 &rhs) { \
225+ return lhs op to_float (rhs.value ); \
226+ } \
227+ friend bfloat16 operator op (const bfloat16 &lhs, \
228+ const unsigned long long &rhs) { \
229+ return to_float (lhs.value ) op rhs; \
230+ } \
231+ friend bfloat16 operator op (const unsigned long long &lhs, \
232+ const bfloat16 &rhs) { \
233+ return lhs op to_float (rhs.value ); \
234+ }
235+ OP (+)
236+ OP (-)
237+ OP (*)
238+ OP (/)
239+
149240#undef OP
150241
151- // Assignment operators overloading
242+ // Operator ==, !=, <, >, <=, >=
152243#define OP (op ) \
153- friend bfloat16 &operator op (bfloat16 &lhs, const bfloat16 &rhs) { \
154- float f = static_cast <float >(lhs); \
155- f op static_cast <float >(rhs); \
156- return lhs = f; \
157- } \
158- template <typename T> \
159- friend bfloat16 &operator op (bfloat16 &lhs, const T &rhs) { \
160- float f = static_cast <float >(lhs); \
161- f op static_cast <float >(rhs); \
162- return lhs = f; \
163- } \
164- template <typename T> friend T &operator op (T &lhs, const bfloat16 &rhs) { \
165- float f = static_cast <float >(lhs); \
166- f op static_cast <float >(rhs); \
167- return lhs = f; \
168- }
169- OP (+=)
170- OP (-=)
171- OP (*=)
172- OP (/=)
173- #undef OP
244+ friend bool operator op (const bfloat16 &lhs, const bfloat16 &rhs) { \
245+ return to_float (lhs.value ) op to_float (rhs.value ); \
246+ } \
247+ friend bool operator op (const bfloat16 &lhs, const double &rhs) { \
248+ return to_float (lhs.value ) op rhs; \
249+ } \
250+ friend bool operator op (const double &lhs, const bfloat16 &rhs) { \
251+ return lhs op to_float (rhs.value ); \
252+ } \
253+ friend bool operator op (const bfloat16 &lhs, const float &rhs) { \
254+ return to_float (lhs.value ) op rhs; \
255+ } \
256+ friend bool operator op (const float &lhs, const bfloat16 &rhs) { \
257+ return lhs op to_float (rhs.value ); \
258+ } \
259+ friend bool operator op (const bfloat16 &lhs, const int &rhs) { \
260+ return to_float (lhs.value ) op rhs; \
261+ } \
262+ friend bool operator op (const int &lhs, const bfloat16 &rhs) { \
263+ return lhs op to_float (rhs.value ); \
264+ } \
265+ friend bool operator op (const bfloat16 &lhs, const long &rhs) { \
266+ return to_float (lhs.value ) op rhs; \
267+ } \
268+ friend bool operator op (const long &lhs, const bfloat16 &rhs) { \
269+ return lhs op to_float (rhs.value ); \
270+ } \
271+ friend bool operator op (const bfloat16 &lhs, const long long &rhs) { \
272+ return to_float (lhs.value ) op rhs; \
273+ } \
274+ friend bool operator op (const long long &lhs, const bfloat16 &rhs) { \
275+ return lhs op to_float (rhs.value ); \
276+ } \
277+ friend bool operator op (const bfloat16 &lhs, const unsigned int &rhs) { \
278+ return to_float (lhs.value ) op rhs; \
279+ } \
280+ friend bool operator op (const unsigned int &lhs, const bfloat16 &rhs) { \
281+ return lhs op to_float (rhs.value ); \
282+ } \
283+ friend bool operator op (const bfloat16 &lhs, const unsigned long &rhs) { \
284+ return to_float (lhs.value ) op rhs; \
285+ } \
286+ friend bool operator op (const unsigned long &lhs, const bfloat16 &rhs) { \
287+ return lhs op to_float (rhs.value ); \
288+ } \
289+ friend bool operator op (const bfloat16 &lhs, \
290+ const unsigned long long &rhs) { \
291+ return to_float (lhs.value ) op rhs; \
292+ } \
293+ friend bool operator op (const unsigned long long &lhs, \
294+ const bfloat16 &rhs) { \
295+ return lhs op to_float (rhs.value ); \
296+ }
297+ OP (==)
298+ OP (!=)
299+ OP (<)
300+ OP (>)
301+ OP (<=)
302+ OP (>=)
174303
175- // Binary operators overloading
176- #define OP (type, op ) \
177- friend type operator op (const bfloat16 &lhs, const bfloat16 &rhs) { \
178- return type{static_cast <float >(lhs) op static_cast <float >(rhs)}; \
179- } \
180- template <typename T> \
181- friend type operator op (const bfloat16 &lhs, const T &rhs) { \
182- return type{static_cast <float >(lhs) op static_cast <float >(rhs)}; \
183- } \
184- template <typename T> \
185- friend type operator op (const T &lhs, const bfloat16 &rhs) { \
186- return type{static_cast <float >(lhs) op static_cast <float >(rhs)}; \
187- }
188- OP (bfloat16, +)
189- OP (bfloat16, -)
190- OP (bfloat16, *)
191- OP (bfloat16, /)
192- OP (bool , ==)
193- OP (bool , !=)
194- OP (bool , <)
195- OP (bool , >)
196- OP (bool , <=)
197- OP (bool , >=)
198304#undef OP
199305
200306 // Bitwise(|,&,~,^), modulo(%) and shift(<<,>>) operations are not supported
0 commit comments