Skip to content

Torch 2.2 breakage on bfloat16 and float16 #8

@proger

Description

@proger

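Running the Triton implementation with torch 2.2 on float16 and bfloat16 inputs results in the following error. Kernel compilation of `forward_scan` fails because a compile-time static assertion is not satisfied (`CompileTimeAssertionFailure`):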

  File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/home/proger/accelerated-scan/accelerated_scan/triton.py", line 144, in <module>
    out = scan(gates, tokens)
  File "/home/proger/accelerated-scan/accelerated_scan/triton.py", line 129, in scan
    return Scan.apply(gates, tokens)
  File "/home/proger/.local/lib/python3.10/site-packages/torch/autograd/function.py", line 553, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/home/proger/accelerated-scan/accelerated_scan/triton.py", line 87, in forward
    forward_scan[(B,C)](gates, tokens, states, SEQUENCE_LENGTH=T, enable_fp_fusion=False)
  File "/home/proger/.local/lib/python3.10/site-packages/triton/runtime/jit.py", line 532, in run
    self.cache[device][key] = compile(
  File "/home/proger/.local/lib/python3.10/site-packages/triton/compiler/compiler.py", line 543, in compile
    next_module = compile_kernel(module)
  File "/home/proger/.local/lib/python3.10/site-packages/triton/compiler/compiler.py", line 435, in <lambda>
    ast_to_ttir(src, signature, configs[0], constants, debug=debug, target=target), target))
  File "/home/proger/.local/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 1228, in ast_to_ttir
    generator.visit(fn.parse())
  File "/home/proger/.local/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 1105, in visit
    ret = super().visit(node)
  File "/usr/lib/python3.10/ast.py", line 418, in visit
    return visitor(node)
  File "/home/proger/.local/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 303, in visit_Module
    ast.NodeVisitor.generic_visit(self, node)
  File "/usr/lib/python3.10/ast.py", line 426, in generic_visit
    self.visit(item)
  File "/home/proger/.local/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 1105, in visit
    ret = super().visit(node)
  File "/usr/lib/python3.10/ast.py", line 418, in visit
    return visitor(node)
  File "/home/proger/.local/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 376, in visit_FunctionDef
    self.visit_compound_statement(node.body)
  File "/home/proger/.local/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 298, in visit_compound_statement
    ret_type = self.visit(stmt)
  File "/home/proger/.local/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 1105, in visit
    ret = super().visit(node)
  File "/usr/lib/python3.10/ast.py", line 418, in visit
    return visitor(node)
  File "/home/proger/.local/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 428, in visit_Assign
    values = self.visit(node.value)
  File "/home/proger/.local/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 1105, in visit
    ret = super().visit(node)
  File "/usr/lib/python3.10/ast.py", line 418, in visit
    return visitor(node)
  File "/home/proger/.local/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 1021, in visit_Call
    return self.call_JitFunction(fn, args, kws)
  File "/home/proger/.local/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 989, in call_JitFunction
    generator.visit(fn.parse())
  File "/home/proger/.local/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 1105, in visit
    ret = super().visit(node)
  File "/usr/lib/python3.10/ast.py", line 418, in visit
    return visitor(node)
  File "/home/proger/.local/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 303, in visit_Module
    ast.NodeVisitor.generic_visit(self, node)
  File "/usr/lib/python3.10/ast.py", line 426, in generic_visit
    self.visit(item)
  File "/home/proger/.local/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 1105, in visit
    ret = super().visit(node)
  File "/usr/lib/python3.10/ast.py", line 418, in visit
    return visitor(node)
  File "/home/proger/.local/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 376, in visit_FunctionDef
    self.visit_compound_statement(node.body)
  File "/home/proger/.local/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 298, in visit_compound_statement
    ret_type = self.visit(stmt)
  File "/home/proger/.local/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 1105, in visit
    ret = super().visit(node)
  File "/usr/lib/python3.10/ast.py", line 418, in visit
    return visitor(node)
  File "/home/proger/.local/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 1069, in visit_Expr
    ast.NodeVisitor.generic_visit(self, node)
  File "/usr/lib/python3.10/ast.py", line 428, in generic_visit
    self.visit(value)
  File "/home/proger/.local/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 1105, in visit
    ret = super().visit(node)
  File "/usr/lib/python3.10/ast.py", line 418, in visit
    return visitor(node)
  File "/home/proger/.local/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 1012, in visit_Call
    return static_implementation(self, node)
  File "/home/proger/.local/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 1140, in execute_static_assert
    raise CompileTimeAssertionFailure(None, node, _unwrap_if_constexpr(message))
triton.compiler.errors.CompileTimeAssertionFailure: at 2:4:def forward_scan(
    gates,
    ^

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions