Skip to content

Conversation

@hawkinsp
Copy link
Contributor

  • We can call .results without figuring out whether we have an Operation or an OpView, and that's likely the common case anyway.
  • If we have one or more results, we can return them directly, with no need for a call to get_op_result_or_value. We're guaranteed that .results returns a PyOpResultList, so we have either an OpResult or sequence of OpResults, just as the API expects.

This saves a few 100ms during IR construction in an LLM JAX benchmark.

* We can call .results without figuring out whether we have an Operation
  or an OpView, and that's likely the common case anyway.
* If we have one or more results, we can return them directly, with no
  need for a call to get_op_result_or_value. We're guaranteed that
  .results returns a PyOpResultList, so we have either an OpResult or
  sequence of OpResults, just as the API expects.

This saves a few 100ms during IR construction in an LLM JAX benchmark.
@llvmbot
Copy link
Member

llvmbot commented Jan 22, 2025

@llvm/pr-subscribers-mlir

Author: Peter Hawkins (hawkinsp)

Changes
  • We can call .results without figuring out whether we have an Operation or an OpView, and that's likely the common case anyway.
  • If we have one or more results, we can return them directly, with no need for a call to get_op_result_or_value. We're guaranteed that .results returns a PyOpResultList, so we have either an OpResult or sequence of OpResults, just as the API expects.

This saves a few 100ms during IR construction in an LLM JAX benchmark.


Full diff: https://github.com/llvm/llvm-project/pull/123866.diff

1 Files Affected:

  • (modified) mlir/python/mlir/dialects/_ods_common.py (+10-9)
diff --git a/mlir/python/mlir/dialects/_ods_common.py b/mlir/python/mlir/dialects/_ods_common.py
index d40d936cdc83d6..5b67ab03d6f494 100644
--- a/mlir/python/mlir/dialects/_ods_common.py
+++ b/mlir/python/mlir/dialects/_ods_common.py
@@ -133,15 +133,16 @@ def get_op_results_or_values(
 def get_op_result_or_op_results(
     op: _Union[_cext.ir.OpView, _cext.ir.Operation],
 ) -> _Union[_cext.ir.Operation, _cext.ir.OpResult, _Sequence[_cext.ir.OpResult]]:
-    if isinstance(op, _cext.ir.OpView):
-        op = op.operation
-    return (
-        list(get_op_results_or_values(op))
-        if len(op.results) > 1
-        else get_op_result_or_value(op)
-        if len(op.results) > 0
-        else op
-    )
+    results = op.results
+    num_results = len(results)
+    if num_results == 1:
+        return results[0]
+    elif num_results > 1:
+        return results
+    elif isinstance(op, _cext.ir.OpView):
+        return op.operation
+    else:
+        return op
 
 ResultValueTypeTuple = _cext.ir.Operation, _cext.ir.OpView, _cext.ir.Value
 ResultValueT = _Union[ResultValueTypeTuple]

Copy link
Member

@jpienaar jpienaar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice, thanks

@jpienaar jpienaar merged commit ff0f1dd into llvm:main Jan 22, 2025
11 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

mlir:python MLIR Python bindings mlir

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants