@@ -1431,6 +1431,12 @@ class Alloc(COp):

    __props__ = ()

+    _runtime_broadcast_error_msg = (
+        "Runtime broadcasting not allowed. "
+        "The output of Alloc requires broadcasting a dimension of the input value, which was not marked as broadcastable. "
+        "If broadcasting was intended, use `specify_broadcastable` on the relevant input."
+    )
+
    def make_node(self, value, *shape):
        value = as_tensor_variable(value)
        shape, static_shape = infer_static_shape(shape)
@@ -1468,10 +1474,21 @@ def make_node(self, value, *shape):
        otype = TensorType(dtype=value.dtype, shape=combined_static_shape)
        return Apply(self, [value] + shape, [otype()])

+    @staticmethod
+    def _check_runtime_broadcast(node, value, shape):
+        value_static_shape = node.inputs[0].type.shape
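+        # Compare dims right-aligned, as in NumPy broadcasting rules; zip stops
+        # at the shortest sequence, so extra leading output dims are always fine.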
+        for v_static_dim, value_dim, out_dim in zip(
+            value_static_shape[::-1], value.shape[::-1], shape[::-1]
+        ):
+            if v_static_dim is None and value_dim == 1 and out_dim != 1:
+                raise ValueError(Alloc._runtime_broadcast_error_msg)
+
    def perform(self, node, inputs, out_):
        (out,) = out_
        v = inputs[0]
        sh = tuple([int(i) for i in inputs[1:]])
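+        # Reject implicit runtime broadcasting of the value before allocating.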
+        self._check_runtime_broadcast(node, v, sh)
+
        if out[0] is None or out[0].shape != sh:
            if v.size == 1 and v.item() == 0:
                out[0] = np.zeros(sh, dtype=v.dtype)
@@ -1484,12 +1501,19 @@ def perform(self, node, inputs, out_):

    def c_code(self, node, name, inp, out, sub):
        vv = inp[0]
-        ndim = len(inp[1:])
        (zz,) = out
        fail = sub["fail"]

+        v_static_shape = node.inputs[0].type.shape
+        o_static_shape = node.outputs[0].type.shape
+        v_ndim = len(v_static_shape)
+        o_ndim = len(o_static_shape)
+        assert o_ndim == len(inp[1:])
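+        # The value may have fewer dims than the output: Alloc can prepend new
+        # leading dims, so value dims align with the trailing output dims.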
+
+        # Declare variables
        code = f"""
-            npy_intp shape[{ndim}];
+            npy_intp shape[{o_ndim}];
+            int need_new_out;
            """

        # Initialize shape
@@ -1498,15 +1522,26 @@ def c_code(self, node, name, inp, out, sub):
            shape[{i}] = ((dtype_{shp_i}*) PyArray_DATA({shp_i}))[0];
            """

+        # Add checks for runtime broadcasting
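+        # A check is only needed where the static dim is unknown (None): a static
+        # dim of 1 is explicitly broadcastable, and a static dim > 1 can never
+        # be 1 at runtime.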
+        for i, v_static_dim in enumerate(v_static_shape[::-1]):
+            if v_static_dim is None:
+                code += f"""
+                if (PyArray_DIMS({vv})[{v_ndim - i - 1}] == 1 && shape[{o_ndim - i - 1}] != 1)
+                {{
+                    PyErr_Format(PyExc_ValueError, "{self._runtime_broadcast_error_msg}");
+                    {fail}
+                }}
+                """
+
        code += f"""
-            int need_new_out = (NULL == {zz});
-            for (int i = 0; i < {ndim}; i++)
+            need_new_out = (NULL == {zz});
+            for (int i = 0; i < {o_ndim}; i++)
                need_new_out = (need_new_out || (PyArray_DIMS({zz})[i] != shape[i]));

            if (need_new_out)
            {{
                Py_XDECREF({zz});
-                {zz} = (PyArrayObject*) PyArray_SimpleNew({ndim}, shape, PyArray_TYPE({vv}));
+                {zz} = (PyArrayObject*) PyArray_SimpleNew({o_ndim}, shape, PyArray_TYPE({vv}));
                if (!{zz})
                {{
                    PyErr_SetString(PyExc_MemoryError, "alloc failed");
@@ -1522,7 +1557,7 @@ def c_code(self, node, name, inp, out, sub):
        return code

    def c_code_cache_version(self):
-        return (3,)
+        return (4,)

    def infer_shape(self, fgraph, node, input_shapes):
        return [node.inputs[1:]]
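
A minimal sketch of the user-facing behavior this change enforces, assuming the
public `pytensor.tensor.alloc` and `pytensor.tensor.specify_broadcastable`
helpers (neither is shown in this diff):

    import numpy as np
    import pytensor
    import pytensor.tensor as pt

    x = pt.vector("x")  # static shape (None,): not known to be length 1
    f = pytensor.function([x], pt.alloc(x, 5, 3))
    f(np.zeros(3, dtype=x.dtype))  # OK: trailing dim matches, no broadcasting
    f(np.zeros(1, dtype=x.dtype))  # now raises ValueError at runtime

    # Declaring the intent makes the broadcast legal again:
    xb = pt.specify_broadcastable(x, 0)  # static shape becomes (1,)
    fb = pytensor.function([x], pt.alloc(xb, 5, 3))
    fb(np.zeros(1, dtype=x.dtype))  # OK: returns a (5, 3) array of zeros

Because the check only fires on dims whose static shape is unknown, graphs that
already encode broadcastability in their types are unaffected.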