Skip to content

Commit 1c9d97e

Browse files
authored
[NEP50] Implement NumPy 2.x type promotion for unsigned array + signed scalar (#572)
Implements NEP 50 type promotion (NumPy 2.x) for unsigned array + signed scalar combinations. Changes: - 12 entries updated in _typemap_arr_scalar - 59 tests with 100% coverage - Comprehensive documentation added Closes #529
1 parent 602b9fc commit 1c9d97e

File tree

3 files changed

+1153
-12
lines changed

3 files changed

+1153
-12
lines changed

src/NumSharp.Core/Logic/np.find_common_type.cs

Lines changed: 160 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,82 @@
88

99
namespace NumSharp
1010
{
11+
// ================================================================================
12+
// TYPE PROMOTION SYSTEM
13+
// ================================================================================
14+
//
15+
// This file implements NumPy-compatible type promotion for arithmetic operations.
16+
// When two arrays (or an array and a scalar) are combined, this system determines
17+
// the result dtype.
18+
//
19+
// ARCHITECTURE
20+
// ============
21+
//
22+
// Four lookup tables are used (two pairs for Type and NPTypeCode access):
23+
//
24+
// _typemap_arr_arr / _nptypemap_arr_arr - Array + Array promotion
25+
// _typemap_arr_scalar / _nptypemap_arr_scalar - Array + Scalar promotion
26+
//
27+
// The tables are FrozenDictionary<(T1, T2), TResult> for O(1) lookup.
28+
//
29+
// WHEN EACH TABLE IS USED
30+
// =======================
31+
//
32+
// The _FindCommonType(NDArray, NDArray) method decides which table to use:
33+
//
34+
// if (both are non-scalar arrays) → _typemap_arr_arr
35+
// if (both are scalar arrays) → _FindCommonScalarType (uses arr_arr rules)
36+
// if (one is array, one is scalar) → _typemap_arr_scalar
37+
//
38+
// This matters because scalar promotion follows different rules than array promotion.
39+
//
40+
// KIND HIERARCHY
41+
// ==============
42+
//
43+
// Types are grouped into "kinds" with a promotion hierarchy:
44+
//
45+
// boolean < integer < floating-point < complex
46+
//
47+
// When operands are of different kinds, the result promotes to the higher kind:
48+
//
49+
// int32 + float32 → float64 (int promotes to float)
50+
// float32 + complex → complex (float promotes to complex)
51+
//
52+
// WITHIN-KIND PROMOTION
53+
// =====================
54+
//
55+
// When operands are the same kind, promotion depends on the operation type:
56+
//
57+
// Array + Array (both non-scalar):
58+
// - Result is the "larger" type that can hold both ranges
59+
// - uint8 + int16 → int16 (int16 can hold uint8 range + negatives)
60+
// - uint32 + int32 → int64 (need 64-bit to hold both ranges)
61+
// - uint64 + int64 → float64 (no integer type can hold both!)
62+
//
63+
// Array + Scalar (NEP 50 behavior):
64+
// - Array dtype wins when scalar is same-kind (e.g., both integers)
65+
// - uint8_array + int32_scalar → uint8 (array wins)
66+
// - float32_array + int32_scalar → float32 (array wins, same effective kind)
67+
//
68+
// EXAMPLES
69+
// ========
70+
//
71+
// var a = np.array(new byte[] {1, 2, 3}); // uint8
72+
// var b = np.array(new int[] {4, 5, 6}); // int32
73+
//
74+
// (a + b).dtype == np.int32 // arr+arr: promotes to int32
75+
// (a + 5).dtype == np.uint8 // arr+scalar: array wins (NEP 50)
76+
// (a + 5.0).dtype == np.float64 // cross-kind: float wins
77+
//
78+
// REFERENCES
79+
// ==========
80+
//
81+
// - NumPy type promotion: https://numpy.org/doc/stable/reference/ufuncs.html#type-casting-rules
82+
// - NEP 50 (scalar promotion): https://numpy.org/neps/nep-0050-scalar-promotion.html
83+
// - Array API type promotion: https://data-apis.org/array-api/latest/API_specification/type_promotion.html
84+
//
85+
// ================================================================================
86+
1187
[SuppressMessage("ReSharper", "StaticMemberInitializerReferesToMemberBelow")]
1288
public static partial class np
1389
{
@@ -50,6 +126,39 @@ static np()
50126
{
51127
#region arr_arr
52128

129+
// ============================================================================
130+
// ARRAY-ARRAY TYPE PROMOTION TABLE
131+
// ============================================================================
132+
//
133+
// This table defines type promotion when TWO ARRAYS are combined.
134+
// The key is (LeftArrayType, RightArrayType), the value is the result type.
135+
//
136+
// PROMOTION RULES:
137+
//
138+
// 1. Same type: result is that type
139+
// int32 + int32 → int32
140+
//
141+
// 2. Same kind, different size: result is larger type
142+
// int16 + int32 → int32
143+
// float32 + float64 → float64
144+
//
145+
// 3. Signed + Unsigned (same size): result is next-larger signed type
146+
// int16 + uint16 → int32 (need more bits for both ranges)
147+
// int32 + uint32 → int64
148+
// int64 + uint64 → float64 (no larger integer exists!)
149+
//
150+
// 4. Cross-kind: result is the higher kind
151+
// int32 + float32 → float64 (int32 needs float64 precision)
152+
// uint8 + float32 → float32 (uint8 fits in float32)
153+
//
154+
// 5. Complex: absorbs everything
155+
// float32 + complex64 → complex64
156+
// int32 + complex64 → complex128 (int32 needs float64 precision)
157+
//
158+
// This table matches NumPy 2.x arr+arr behavior exactly.
159+
//
160+
// ============================================================================
161+
53162
var typemap_arr_arr = new Dictionary<(Type, Type), Type>(180);
54163
typemap_arr_arr.Add((np.@bool, np.@bool), np.@bool);
55164
typemap_arr_arr.Add((np.@bool, np.uint8), np.uint8);
@@ -243,6 +352,45 @@ static np()
243352

244353
#region arr_scalar
245354

355+
// ============================================================================
356+
// ARRAY-SCALAR TYPE PROMOTION TABLE
357+
// ============================================================================
358+
//
359+
// This table defines type promotion when an array operates with a scalar value.
360+
// The key is (ArrayType, ScalarType), the value is the result type.
361+
//
362+
// NUMSHARP DESIGN DECISION:
363+
// C# primitive scalars (int, short, long, etc.) are treated as "weakly typed"
364+
// like Python scalars in NumPy 2.x, NOT like NumPy scalars (np.int32, etc.).
365+
//
366+
// This means: np.array(new byte[]{1,2,3}) + 5 → uint8 result (not int32)
367+
//
368+
// WHY: This matches the natural Python/NumPy user experience where `arr + 5`
369+
// preserves the array's dtype when both are integers. This is consistent with
370+
// NumPy 2.x behavior under NEP 50 for Python scalar operands.
371+
//
372+
// NEP 50 (NumPy Enhancement Proposal 50):
373+
// https://numpy.org/neps/nep-0050-scalar-promotion.html
374+
//
375+
// Key rule: When an array operates with a scalar of the same "kind" (e.g., both
376+
// are integers), the array dtype wins. Cross-kind operations (int + float) still
377+
// promote to the higher kind (float).
378+
//
379+
// AFFECTED ENTRIES (12 total - all unsigned array + signed scalar):
380+
//
381+
// | Array Type | Scalar Types | NumPy 1.x Result | NumPy 2.x Result |
382+
// |------------|-------------------|------------------|------------------|
383+
// | uint8 | int16/int32/int64 | int16/int32/int64| uint8 |
384+
// | uint16 | int16/int32/int64 | int32/int32/int64| uint16 |
385+
// | uint32 | int16/int32/int64 | int64/int64/int64| uint32 |
386+
// | uint64 | int16/int32/int64 | float64 (!) | uint64 |
387+
//
388+
// Verified against NumPy 2.4.2:
389+
// >>> (np.array([1,2,3], np.uint8) + 5).dtype
390+
// dtype('uint8')
391+
//
392+
// ============================================================================
393+
246394
var typemap_arr_scalar = new Dictionary<(Type, Type), Type>();
247395
typemap_arr_scalar.Add((np.@bool, np.@bool), np.@bool);
248396
typemap_arr_scalar.Add((np.@bool, np.uint8), np.uint8);
@@ -259,11 +407,11 @@ static np()
259407
typemap_arr_scalar.Add((np.uint8, np.@bool), np.uint8);
260408
typemap_arr_scalar.Add((np.uint8, np.uint8), np.uint8);
261409
typemap_arr_scalar.Add((np.uint8, np.@char), np.uint8);
262-
typemap_arr_scalar.Add((np.uint8, np.int16), np.int16);
410+
typemap_arr_scalar.Add((np.uint8, np.int16), np.uint8);
263411
typemap_arr_scalar.Add((np.uint8, np.uint16), np.uint8);
264-
typemap_arr_scalar.Add((np.uint8, np.int32), np.int32);
412+
typemap_arr_scalar.Add((np.uint8, np.int32), np.uint8);
265413
typemap_arr_scalar.Add((np.uint8, np.uint32), np.uint8);
266-
typemap_arr_scalar.Add((np.uint8, np.int64), np.int64);
414+
typemap_arr_scalar.Add((np.uint8, np.int64), np.uint8);
267415
typemap_arr_scalar.Add((np.uint8, np.uint64), np.uint8);
268416
typemap_arr_scalar.Add((np.uint8, np.float32), np.float32);
269417
typemap_arr_scalar.Add((np.uint8, np.float64), np.float64);
@@ -298,11 +446,11 @@ static np()
298446
typemap_arr_scalar.Add((np.uint16, np.@bool), np.uint16);
299447
typemap_arr_scalar.Add((np.uint16, np.uint8), np.uint16);
300448
typemap_arr_scalar.Add((np.uint16, np.@char), np.uint16);
301-
typemap_arr_scalar.Add((np.uint16, np.int16), np.int32);
449+
typemap_arr_scalar.Add((np.uint16, np.int16), np.uint16);
302450
typemap_arr_scalar.Add((np.uint16, np.uint16), np.uint16);
303-
typemap_arr_scalar.Add((np.uint16, np.int32), np.int32);
451+
typemap_arr_scalar.Add((np.uint16, np.int32), np.uint16);
304452
typemap_arr_scalar.Add((np.uint16, np.uint32), np.uint16);
305-
typemap_arr_scalar.Add((np.uint16, np.int64), np.int64);
453+
typemap_arr_scalar.Add((np.uint16, np.int64), np.uint16);
306454
typemap_arr_scalar.Add((np.uint16, np.uint64), np.uint16);
307455
typemap_arr_scalar.Add((np.uint16, np.float32), np.float32);
308456
typemap_arr_scalar.Add((np.uint16, np.float64), np.float64);
@@ -324,11 +472,11 @@ static np()
324472
typemap_arr_scalar.Add((np.uint32, np.@bool), np.uint32);
325473
typemap_arr_scalar.Add((np.uint32, np.uint8), np.uint32);
326474
typemap_arr_scalar.Add((np.uint32, np.@char), np.uint32);
327-
typemap_arr_scalar.Add((np.uint32, np.int16), np.int64);
475+
typemap_arr_scalar.Add((np.uint32, np.int16), np.uint32);
328476
typemap_arr_scalar.Add((np.uint32, np.uint16), np.uint32);
329-
typemap_arr_scalar.Add((np.uint32, np.int32), np.int64);
477+
typemap_arr_scalar.Add((np.uint32, np.int32), np.uint32);
330478
typemap_arr_scalar.Add((np.uint32, np.uint32), np.uint32);
331-
typemap_arr_scalar.Add((np.uint32, np.int64), np.int64);
479+
typemap_arr_scalar.Add((np.uint32, np.int64), np.uint32);
332480
typemap_arr_scalar.Add((np.uint32, np.uint64), np.uint32);
333481
typemap_arr_scalar.Add((np.uint32, np.float32), np.float64);
334482
typemap_arr_scalar.Add((np.uint32, np.float64), np.float64);
@@ -350,11 +498,11 @@ static np()
350498
typemap_arr_scalar.Add((np.uint64, np.@bool), np.uint64);
351499
typemap_arr_scalar.Add((np.uint64, np.uint8), np.uint64);
352500
typemap_arr_scalar.Add((np.uint64, np.@char), np.uint64);
353-
typemap_arr_scalar.Add((np.uint64, np.int16), np.float64);
501+
typemap_arr_scalar.Add((np.uint64, np.int16), np.uint64);
354502
typemap_arr_scalar.Add((np.uint64, np.uint16), np.uint64);
355-
typemap_arr_scalar.Add((np.uint64, np.int32), np.float64);
503+
typemap_arr_scalar.Add((np.uint64, np.int32), np.uint64);
356504
typemap_arr_scalar.Add((np.uint64, np.uint32), np.uint64);
357-
typemap_arr_scalar.Add((np.uint64, np.int64), np.float64);
505+
typemap_arr_scalar.Add((np.uint64, np.int64), np.uint64);
358506
typemap_arr_scalar.Add((np.uint64, np.uint64), np.uint64);
359507
typemap_arr_scalar.Add((np.uint64, np.float32), np.float64);
360508
typemap_arr_scalar.Add((np.uint64, np.float64), np.float64);

0 commit comments

Comments
 (0)