Skip to content

Commit 275c859

Browse files
committed
separate setitem_masked functions for scalars and vectors
1 parent d7dba29 commit 275c859

File tree

2 files changed

+43
-35
lines changed

2 files changed

+43
-35
lines changed

maps/src/FlatSkyMap.cxx

Lines changed: 21 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1089,28 +1089,31 @@ flatskymap_getitem_masked(const FlatSkyMap &skymap, const G3SkyMapMask &m)
10891089
return out;
10901090
}
10911091

1092+
static void
1093+
flatskymap_setitem_masked_scalar(FlatSkyMap &skymap, const G3SkyMapMask &m,
1094+
double dval)
1095+
{
1096+
g3_assert(m.IsCompatible(skymap));
1097+
1098+
for (auto i : skymap) {
1099+
if (m.at(i.first))
1100+
skymap[i.first] = dval;
1101+
}
1102+
}
1103+
10921104
static void
10931105
flatskymap_setitem_masked(FlatSkyMap &skymap, const G3SkyMapMask &m,
1094-
py::object val)
1106+
const std::vector<double> &val)
10951107
{
10961108
g3_assert(m.IsCompatible(skymap));
10971109

1098-
if (py::extract<double>(val).check()) {
1099-
double dval = py::extract<double>(val)();
1100-
for (auto i : skymap) {
1101-
if (m.at(i.first))
1102-
skymap[i.first] = dval;
1103-
}
1104-
} else {
1105-
// XXX: the iterable case probably be optimized for numpy arrays
1106-
// XXX: check for size congruence first?
1107-
size_t j = 0;
1108-
for (auto i : skymap) {
1109-
if (m.at(i.first)) {
1110-
skymap[i.first] = py::extract<double>(val[j])();
1111-
j++;
1112-
}
1113-
}
1110+
if (val.size() != m.sum())
1111+
throw py::value_error("Item dimensions do not match masked area");
1112+
1113+
size_t j = 0;
1114+
for (auto i : skymap) {
1115+
if (m.at(i.first))
1116+
skymap[i.first] = val[j++];
11141117
}
11151118
}
11161119

@@ -1373,6 +1376,7 @@ PYBINDINGS("maps", scope)
13731376
.def("__setitem__", flatskymap_setslice_1d)
13741377
.def("__getitem__", flatskymap_getitem_masked)
13751378
.def("__setitem__", flatskymap_setitem_masked)
1379+
.def("__setitem__", flatskymap_setitem_masked_scalar)
13761380
;
13771381

13781382
// Add buffer protocol interface

maps/src/HealpixSkyMap.cxx

Lines changed: 22 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1101,28 +1101,31 @@ HealpixSkyMap_getitem_masked(const HealpixSkyMap &skymap, const G3SkyMapMask &m)
11011101
return out;
11021102
}
11031103

1104+
static void
1105+
HealpixSkyMap_setitem_masked_scalar(HealpixSkyMap &skymap, const G3SkyMapMask &m,
1106+
double dval)
1107+
{
1108+
g3_assert(m.IsCompatible(skymap));
1109+
1110+
for (auto i : skymap) {
1111+
if (m.at(i.first))
1112+
skymap[i.first] = dval;
1113+
}
1114+
}
1115+
11041116
static void
11051117
HealpixSkyMap_setitem_masked(HealpixSkyMap &skymap, const G3SkyMapMask &m,
1106-
py::object val)
1118+
const std::vector<double> &val)
11071119
{
11081120
g3_assert(m.IsCompatible(skymap));
11091121

1110-
if (py::extract<double>(val).check()) {
1111-
double dval = py::extract<double>(val)();
1112-
for (auto i : skymap) {
1113-
if (m.at(i.first))
1114-
skymap[i.first] = dval;
1115-
}
1116-
} else {
1117-
// XXX: the iterable case probably be optimized for numpy arrays
1118-
// XXX: check for size congruence first?
1119-
size_t j = 0;
1120-
for (auto i : skymap) {
1121-
if (m.at(i.first)) {
1122-
skymap[i.first] = py::extract<double>(val[j])();
1123-
j++;
1124-
}
1125-
}
1122+
if (val.size() != m.sum())
1123+
throw py::value_error("Item dimensions do not match masked area");
1124+
1125+
size_t j = 0;
1126+
for (auto i : skymap) {
1127+
if (m.at(i.first))
1128+
skymap[i.first] = val[j++];
11261129
}
11271130
}
11281131

@@ -1226,9 +1229,9 @@ PYBINDINGS("maps", scope)
12261229
.def(py::init<const std::vector<uint64_t> &, const std::vector<double> &,
12271230
size_t, bool, bool, MapCoordReference, G3Timestream::TimestreamUnits,
12281231
G3SkyMap::MapPolType, G3SkyMap::MapPolConv>(),
1229-
py::arg("nside"),
12301232
py::arg("index"),
12311233
py::arg("data"),
1234+
py::arg("nside"),
12321235
py::arg("weighted") = true,
12331236
py::arg("nested") = false,
12341237
py::arg("coord_ref") = MapCoordReference::Equatorial,
@@ -1280,6 +1283,7 @@ PYBINDINGS("maps", scope)
12801283
.def("__setitem__", &skymap_setitem)
12811284
.def("__getitem__", HealpixSkyMap_getitem_masked)
12821285
.def("__setitem__", HealpixSkyMap_setitem_masked)
1286+
.def("__setitem__", HealpixSkyMap_setitem_masked_scalar)
12831287

12841288
.def("nonzero_pixels", &HealpixSkyMap_nonzeropixels,
12851289
"Returns a list of the indices of the non-zero pixels in the "

0 commit comments

Comments
 (0)