Skip to content

Commit a85c6b1

Browse files
committed
add query method
1 parent 81dcdd2 commit a85c6b1

File tree

2 files changed

+227
-1
lines changed

2 files changed

+227
-1
lines changed

h5json/hdf5db.py

Lines changed: 175 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2183,6 +2183,181 @@ def getDatasetValuesByUuid(self, obj_uuid, slices=Ellipsis, format="json"):
21832183
values = values.tobytes()
21842184

21852185
return values
2186+
2187+
"""
2188+
doDatasetQueryByUuid: return rows based on query string
2189+
Return rows from a dataset that matches query string.
2190+
2191+
Note: Only supported for compound_type/one-dimensional datasets
2192+
"""
2193+
def doDatasetQueryByUuid(self, obj_uuid, query, start=0, stop=-1, step=1, limit=None):
2194+
self.log.info("doQueryByUuid - uuid: " + obj_uuid + " query:" + query)
2195+
self.log.info("start: " + str(start) + " stop: " + str(stop) + " step: " + str(step) + " limit: " + str(limit))
2196+
dset = self.getDatasetObjByUuid(obj_uuid)
2197+
if dset is None:
2198+
msg = "Dataset: " + obj_uuid + " not found"
2199+
self.log.info(msg)
2200+
raise IOError(errno.ENXIO, msg)
2201+
2202+
values = []
2203+
dt = dset.dtype
2204+
typeItem = getTypeItem(dt)
2205+
itemSize = getItemSize(typeItem)
2206+
if typeItem['class'] != "H5T_COMPOUND":
2207+
msg = "Only compound type datasets can be used as query target"
2208+
self.log.info(msg)
2209+
raise IOError(errno.EINVAL, msg)
2210+
2211+
if dset.shape is None:
2212+
# null space dataset (with h5py 2.6.0)
2213+
return None
2214+
2215+
rank = len(dset.shape)
2216+
if rank != 1:
2217+
msg = "One one-dimensional datasets can be used as query target"
2218+
self.log.info(msg)
2219+
raise IOError(errno.EINVAL, msg)
2220+
2221+
2222+
values = []
2223+
indexes = []
2224+
count = 0
2225+
2226+
num_elements = dset.shape[0]
2227+
if stop == -1:
2228+
stop = num_elements
2229+
elif stop > num_elements:
2230+
stop = num_elements
2231+
block_size = self._getBlockSize(dset)
2232+
self.log.info("block_size: " + str(block_size))
2233+
2234+
field_names = list(dset.dtype.fields.keys())
2235+
eval_str = self._getEvalStr(query, field_names)
2236+
2237+
while start < stop:
2238+
if limit and (count == limit):
2239+
break # no more rows for this batch
2240+
end = start + block_size
2241+
if end > stop:
2242+
end = stop
2243+
rows = dset[start:end] # read from dataset
2244+
where_result = np.where(eval(eval_str))
2245+
index = where_result[0].tolist()
2246+
if len(index) > 0:
2247+
for i in index:
2248+
row = rows[i]
2249+
item = self.bytesArrayToList(row)
2250+
values.append(item)
2251+
indexes.append(start + i)
2252+
count += 1
2253+
if limit and (count == limit):
2254+
break # no more rows for this batch
2255+
2256+
start = end # go to next block
2257+
2258+
2259+
# values = self.getDataValue(item_type, values, dimension=1, dims=(len(values),))
2260+
2261+
self.log.info("got " + str(count) + " query matches")
2262+
return (indexes, values)
2263+
2264+
"""
2265+
_getBlockSize: Get number of rows to read from disk
2266+
2267+
heurestic to get reasonable sized chunk of data to fetch.
2268+
make multiple of chunk_size if possible
2269+
"""
2270+
def _getBlockSize(self, dset):
2271+
target_block_size = 256 * 1000
2272+
if dset.chunks:
2273+
chunk_size = dset.chunks[0]
2274+
if chunk_size < target_block_size:
2275+
block_size = (target_block_size // chunk_size) * chunk_size
2276+
else:
2277+
block_size = target_block_size
2278+
else:
2279+
block_size = target_block_size
2280+
return block_size
2281+
2282+
"""
2283+
_getEvalStr: Get eval string for given query
2284+
2285+
Gets Eval string to use with numpy where method.
2286+
"""
2287+
def _getEvalStr(self, query, field_names):
2288+
i = 0
2289+
eval_str = ""
2290+
var_name = None
2291+
end_quote_char = None
2292+
var_count = 0
2293+
paren_count = 0
2294+
black_list = ( "import", ) # field names that are not allowed
2295+
self.log.info("getEvalStr(" + query + ")")
2296+
for item in black_list:
2297+
if item in field_names:
2298+
msg = "invalid field name"
2299+
self.log.info("EINVAL: " + msg)
2300+
raise IOError(errno.EINVAL, msg)
2301+
while i < len(query):
2302+
ch = query[i]
2303+
if (i+1) < len(query):
2304+
ch_next = query[i+1]
2305+
else:
2306+
ch_next = None
2307+
if var_name and not ch.isalnum():
2308+
# end of variable
2309+
if var_name not in field_names:
2310+
# invalid
2311+
msg = "unknown field name"
2312+
self.log.info("EINVAL: " + msg)
2313+
raise IOError(errno.EINVAL, msg)
2314+
eval_str += "rows['" + var_name + "']"
2315+
var_name = None
2316+
var_count += 1
2317+
2318+
if end_quote_char:
2319+
if ch == end_quote_char:
2320+
# end of literal
2321+
end_quote_char = None
2322+
eval_str += ch
2323+
elif ch in ("'", '"'):
2324+
end_quote_char = ch
2325+
eval_str += ch
2326+
elif ch.isalpha():
2327+
if ch == 'b' and ch_next in ("'", '"'):
2328+
eval_str += 'b' # start of a byte string literal
2329+
elif var_name is None:
2330+
var_name = ch # start of a variable
2331+
else:
2332+
var_name += ch
2333+
elif ch == '(' and end_quote_char is None:
2334+
paren_count += 1
2335+
eval_str += ch
2336+
elif ch == ')' and end_quote_char is None:
2337+
paren_count -= 1
2338+
if paren_count < 0:
2339+
msg = "Mismatched paren"
2340+
self.log.info("EINVAL: " + msg)
2341+
raise IOError(errno.EINVAL, msg)
2342+
eval_str += ch
2343+
else:
2344+
# just add to eval_str
2345+
eval_str += ch
2346+
i = i+1
2347+
if end_quote_char:
2348+
msg = "no matching quote character"
2349+
self.log.info("EINVAL: " + msg)
2350+
raise IOError(errno.EINVAL, msg)
2351+
if var_count == 0:
2352+
msg = "No field value"
2353+
self.log.info("EINVAL: " + msg)
2354+
raise IOError(errno.EINVAL, msg)
2355+
if paren_count != 0:
2356+
msg = "Mismatched paren"
2357+
self.log.info("EINVAL: " + msg)
2358+
raise IOError(errno.EINVAL, msg)
2359+
2360+
return eval_str
21862361

21872362
"""
21882363
Get values from dataset identified by obj_uuid using the given

test/unit/hdf5dbTest.py

Lines changed: 52 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1162,7 +1162,58 @@ def testRootAcl(self):
11621162
self.assertEqual(acl['delete'], 0)
11631163
self.assertEqual(acl['readACL'], 0)
11641164
self.assertEqual(acl['updateACL'], 0)
1165-
1165+
1166+
def testGetEvalStr(self):
    """Valid queries translate to the expected np.where eval strings."""
    expected = {
        "date == 23": "rows['date'] == 23",
        "wind == b'W 5'": "rows['wind'] == b'W 5'",
        "temp > 61": "rows['temp'] > 61",
        "(date >=22) & (date <= 24)": "(rows['date'] >=22) & (rows['date'] <= 24)",
        "(date == 21) & (temp > 70)": "(rows['date'] == 21) & (rows['temp'] > 70)",
        "(wind == b'E 7') | (wind == b'S 7')": "(rows['wind'] == b'E 7') | (rows['wind'] == b'S 7')",
    }

    fields = ["date", "wind", "temp"]
    filepath = getFile('empty.h5', 'getevalstring.h5')
    with Hdf5db(filepath, app_logger=self.log) as db:
        for query, translation in expected.items():
            self.assertEqual(db._getEvalStr(query, fields), translation)
1183+
def testBadQuery(self):
    """Each malformed query must raise IOError from _getEvalStr."""
    queries = ("foobar",  # no variable used
               "wind = b'abc",  # non-closed literal
               "(wind = b'N') & (temp = 32",  # missing paren
               "foobar > 42",  # invalid field name
               "import subprocess; subprocess.call(['ls', '/'])")  # injection attack

    fields = ("date", "wind", "temp")
    filepath = getFile('empty.h5', 'badquery.h5')
    with Hdf5db(filepath, app_logger=self.log) as db:
        for query in queries:
            # assertRaises replaces the try/except + assertTrue(False)
            # anti-pattern and gives a clearer failure message
            with self.assertRaises(IOError):
                db._getEvalStr(query, fields)
1201+
def testInjectionBlock(self):
    """Blacklisted field names (e.g. 'import') must be rejected."""
    queries = (
        "import subprocess; subprocess.call(['ls', '/'])",)  # injection attack

    fields = ("import", "subprocess", "call")
    filepath = getFile('empty.h5', 'injectionblock.h5')
    with Hdf5db(filepath, app_logger=self.log) as db:
        for query in queries:
            # assertRaises replaces the try/except + assertTrue(False)
            # anti-pattern and gives a clearer failure message
            with self.assertRaises(IOError):
                db._getEvalStr(query, fields)
11661217

11671218

11681219
if __name__ == '__main__':

0 commit comments

Comments
 (0)