|
4 | 4 | from pytensor.graph.basic import Apply, Constant |
5 | 5 | from pytensor.graph.op import Op |
6 | 6 | from pytensor.misc.safe_asarray import _asarray |
7 | | -from pytensor.tensor.basic import arange, as_tensor_variable, flatten, switch |
| 7 | +from pytensor.tensor.basic import arange, as_tensor_variable, switch |
8 | 8 | from pytensor.tensor.math import eq, ge, mul |
9 | | -from pytensor.tensor.shape import shape |
10 | | -from pytensor.tensor.subtensor import set_subtensor |
11 | | -from pytensor.tensor.type import TensorType, integer_dtypes |
| 9 | +from pytensor.tensor.type import TensorType |
12 | 10 |
|
13 | 11 |
|
14 | 12 | def _variable_is_none(var): |
@@ -304,270 +302,3 @@ def _topk_py_impl(op, x, k, axis, idx_dtype): |
304 | 302 | else: |
305 | 303 | zi = np.argpartition(x, -k, axis=axis)[tuple(idx)] |
306 | 304 | return zi.astype(idx_dtype) |
307 | | - |
308 | | - |
309 | | -class TopKOp(Op): |
310 | | - """Operations related to finding k-largest elements. |
311 | | -
|
312 | | - Parameters |
313 | | - ---------- |
314 | | - axis: integer |
315 | | - Defaults to ``-1``. |
316 | | - The axis to perform the operation. Must be in range ``[-ndim, ndim)``, where |
317 | | - ``ndim`` is the dimensionality of input tensor. |
318 | | -
|
319 | | - idx_dtype: string |
320 | | - Specify output dtype for indices, defaults to ``int64``, must be integer type. |
321 | | -
|
322 | | - sorted: bool |
323 | | - NOTE: NOT IMPLEMENTED YET |
324 | | - Defaults to ``True`` |
325 | | -
|
326 | | - If True, the result array would be sorted in descending order. |
327 | | -
|
328 | | -
|
329 | | - Notes |
330 | | - ----- |
331 | | - - The output order is not guaranteed. On the CPU, we use |
332 | | - ``np.partition`` and ``np.argpartition`` that only make sure the |
333 | | - k-th element is the correct one and that the other |
334 | | - elements are on the correct side. |
335 | | - - By default, this Op gives two outputs: values and indices. However |
336 | | - optimizers may remove a certain output if not needed. |
337 | | - - Computing the gradient requests the computation of the indices in |
338 | | - forward pass. |
339 | | - - If the top-k-th value is not unique, we cannot guarantee the |
340 | | - output indices being deterministically chosen. |
341 | | -
|
342 | | - See Also |
343 | | - -------- |
344 | | - topk |
345 | | - argtopk |
346 | | - argtopk_and_topk |
347 | | -
|
348 | | - """ |
349 | | - |
350 | | - # TODO more params |
351 | | - """ |
352 | | - only_top_kth: bool |
353 | | - Defaults to ``False`` |
354 | | -
|
355 | | - If ``True``, will only find one exact top k-th element on given axis. |
356 | | -
|
357 | | - """ |
358 | | - |
359 | | - # TODO c_code |
360 | | - # TODO add opt, if k==1, use max/min reduce |
361 | | - # also if k is axis size, just copy input tensor |
362 | | - # TODO add opt, to merge argtopk / topk |
363 | | - __props__ = ("axis", "sorted", "return_values", "return_indices", "idx_dtype") |
364 | | - |
365 | | - def __init__( |
366 | | - self, |
367 | | - axis=-1, |
368 | | - sorted=True, |
369 | | - idx_dtype="int64", |
370 | | - return_values=True, |
371 | | - return_indices=True, |
372 | | - ): |
373 | | - # numpy always uses int64 as output dtype for arg*() routines |
374 | | - # however, we add "idx_dtype" param as memory is more precious on gpu |
375 | | - if not isinstance(axis, int): |
376 | | - raise TypeError(f'"axis" parameter must be integer, got "{type(axis)}"') |
377 | | - if sorted: |
378 | | - raise NotImplementedError( |
379 | | - "The sorted parameter is not yet implemented. Use sorted=False for now." |
380 | | - ) |
381 | | - if idx_dtype not in integer_dtypes: |
382 | | - raise TypeError( |
383 | | - f'"idx_dtype" parameter must be an integer dtype, got "{idx_dtype}"' |
384 | | - ) |
385 | | - |
386 | | - if not (return_indices or return_values): |
387 | | - raise ValueError( |
388 | | - "Neither return_values nor return_indices is True, this isn't allowed" |
389 | | - ) |
390 | | - |
391 | | - self.axis = axis |
392 | | - self.sorted = sorted |
393 | | - self.return_values = return_values |
394 | | - self.return_indices = return_indices |
395 | | - self.idx_dtype = idx_dtype |
396 | | - |
397 | | - def __str__(self): |
398 | | - return "%(op)s{axis=%(axis)d, sorted=%(sorted)s}" % dict( |
399 | | - op=self.__class__.__name__, axis=self.axis, sorted=self.sorted |
400 | | - ) |
401 | | - |
402 | | - def make_node(self, inp, kth): |
403 | | - inp = as_tensor_variable(inp) |
404 | | - ndim = inp.ndim |
405 | | - if ndim == 0: |
406 | | - raise ValueError("Cannot take scalar as input") |
407 | | - if not -ndim <= self.axis < ndim: |
408 | | - raise IndexError( |
409 | | - '"axis" parameter out of range,' |
410 | | - f" expected integer within [{int(-ndim)}, {int(ndim - 1)}]" |
411 | | - ) |
412 | | - |
413 | | - kth = as_tensor_variable(kth) |
414 | | - _check_tensor_is_scalar(kth) |
415 | | - outs = [] |
416 | | - if self.return_values: |
417 | | - outs.append( |
418 | | - TensorType(dtype=inp.type.dtype, shape=(None,) * inp.type.ndim)() |
419 | | - ) |
420 | | - if self.return_indices: |
421 | | - outs.append( |
422 | | - TensorType(dtype=self.idx_dtype, shape=(None,) * inp.type.ndim)() |
423 | | - ) |
424 | | - return Apply(self, [inp, kth], outs) |
425 | | - |
426 | | - def perform(self, node, inputs, output_storage): |
427 | | - x, k = inputs |
428 | | - axis = self.axis |
429 | | - if not self.return_indices: |
430 | | - pzv = output_storage[0] |
431 | | - pzv[0] = _topk_py_impl(self, x, k, axis, None) |
432 | | - elif self.return_values: |
433 | | - pzv = output_storage[0] |
434 | | - pzi = output_storage[1] |
435 | | - pzv[0], pzi[0] = _topk_py_impl(self, x, k, axis, node.outputs[1].dtype) |
436 | | - else: |
437 | | - pzi = output_storage[0] |
438 | | - pzi[0] = _topk_py_impl(self, x, k, axis, node.outputs[0].dtype) |
439 | | - |
440 | | - def infer_shape(self, fgraph, node, inp_shapes): |
441 | | - shp = list(inp_shapes[0]) |
442 | | - shp[self.axis] = np.abs(node.inputs[1]) |
443 | | - shp = tuple(shp) |
444 | | - return [shp for i in [self.return_values, self.return_indices] if i] |
445 | | - |
446 | | - def L_op(self, inputs, outputs, out_grads): |
447 | | - x, k = inputs |
448 | | - k_grad = grad_undefined(self, 1, k, "topk: k is not differentiable") |
449 | | - |
450 | | - if not (self.return_indices or self.return_values): |
451 | | - x_grad = grad_undefined( |
452 | | - self, |
453 | | - 0, |
454 | | - x, |
455 | | - "topk: cannot get gradient without both indices and values", |
456 | | - ) |
457 | | - else: |
458 | | - x_shp = shape(x) |
459 | | - z_grad = out_grads[0] |
460 | | - ndim = x.ndim |
461 | | - axis = self.axis % ndim |
462 | | - grad_indices = [ |
463 | | - arange(x_shp[i]).dimshuffle([0] + ["x"] * (ndim - i - 1)) |
464 | | - if i != axis |
465 | | - else outputs[-1] |
466 | | - for i in range(ndim) |
467 | | - ] |
468 | | - x_grad = x.zeros_like(dtype=z_grad.dtype) |
469 | | - x_grad = set_subtensor(x_grad[tuple(grad_indices)], z_grad) |
470 | | - |
471 | | - return [x_grad, k_grad] |
472 | | - |
473 | | - |
474 | | -def topk(x, kth, axis=-1, sorted=True, idx_dtype="int64"): |
475 | | - """ |
476 | | - Returns the k-largest elements along an axis. |
477 | | -
|
478 | | - Parameters |
479 | | - ---------- |
480 | | -
|
481 | | - x: tensor instance |
482 | | -
|
483 | | - kth: integer constant/variable |
484 | | - Must not be 0. If negative, gives k-smallest elements instead. |
485 | | -
|
486 | | - axis: integer or ``None`` |
487 | | - Upon which axis shall the operation be performed on. |
488 | | - If ``None``, works on flattened array. |
489 | | -
|
490 | | - sorted: bool |
491 | | - NOTE: NOT IMPLEMENTED YET, USE ``False`` FOR NOW. |
492 | | - Defaults to ``True`` |
493 | | -
|
494 | | - If True, the result array would be sorted in descending order. |
495 | | -
|
496 | | - idx_dtype: string |
497 | | - Specify output dtype used in indices, defaults to ``int64``, must be integer type. |
498 | | - This option is here because indices are needed for gradient. |
499 | | -
|
500 | | - Returns |
501 | | - ------- |
502 | | - Tensor variable with same dtype as `x`. |
503 | | -
|
504 | | - Notes |
505 | | - ----- |
506 | | - - ``sorted=True`` is not supported yet. |
507 | | -
|
508 | | - """ |
509 | | - if axis is None: |
510 | | - x = flatten(x) |
511 | | - axis = 0 |
512 | | - return TopKOp(axis=axis, sorted=sorted, idx_dtype=idx_dtype)(x, kth)[0] |
513 | | - |
514 | | - |
515 | | -def argtopk(x, kth, axis=-1, sorted=True, idx_dtype="int64"): |
516 | | - """ |
517 | | - Returns the indices of k-largest elements along an axis. |
518 | | -
|
519 | | - Parameters |
520 | | - ---------- |
521 | | -
|
522 | | - x: tensor instance |
523 | | -
|
524 | | - kth: integer constant/variable |
525 | | - Must not be 0. If negative, gives k-smallest elements instead. |
526 | | -
|
527 | | - sorted: bool |
528 | | - NOTE: NOT IMPLEMENTED YET, USE ``False`` FOR NOW. |
529 | | - Defaults to ``True`` |
530 | | -
|
531 | | - If True, the result array of corresponding indices would be sorted in descending order. |
532 | | -
|
533 | | -
|
534 | | - axis: integer, tuple/list of integers, or ``None`` |
535 | | - Upon which axis shall the operation be performed on. |
536 | | - If ``None``, works on flattened array. |
537 | | -
|
538 | | - idx_dtype: string |
539 | | - Specify output dtype, defaults to ``int64``, must be integer type. |
540 | | -
|
541 | | - Returns |
542 | | - ------- |
543 | | - Tensor variable with dtype specified in `idx_dtype`. |
544 | | -
|
545 | | - Notes |
546 | | - ----- |
547 | | - - ``sorted=True`` is not supported yet. |
548 | | -
|
549 | | - - If the top-k-th value is not unique, we cannot guarantee the output |
550 | | - indices are deterministically chosen. |
551 | | -
|
552 | | - """ |
553 | | - if axis is None: |
554 | | - x = flatten(x) |
555 | | - axis = 0 |
556 | | - return TopKOp(axis=axis, sorted=sorted, idx_dtype=idx_dtype)(x, kth)[1] |
557 | | - |
558 | | - |
559 | | -def topk_and_argtopk(x, kth, axis=-1, sorted=True, idx_dtype="int64"): |
560 | | - """ |
561 | | - Returns the results of both topk() and argtopk() in one Op. |
562 | | -
|
563 | | - See the respective documentation for details. |
564 | | -
|
565 | | - Returns |
566 | | - ------- |
567 | | - tuple: (values, indices) |
568 | | -
|
569 | | - """ |
570 | | - if axis is None: |
571 | | - x = flatten(x) |
572 | | - axis = 0 |
573 | | - return TopKOp(axis=axis, sorted=sorted, idx_dtype=idx_dtype)(x, kth) |
0 commit comments