Skip to content
Merged
Show file tree
Hide file tree
Changes from 107 commits
Commits
Show all changes
112 commits
Select commit Hold shift + click to select a range
77b3f02
Updates to 'language' support
chrisrichardson Apr 30, 2024
620321a
Fixes for ruff
chrisrichardson Apr 30, 2024
c8ac53a
Add numba
chrisrichardson Apr 30, 2024
4b3d85c
Add suffix support
chrisrichardson May 7, 2024
c29a1f9
Merge branch 'main' into chris/languages-3
chrisrichardson May 7, 2024
0d4d4b4
updates
chrisrichardson May 7, 2024
f316110
Formatting fixes
chrisrichardson May 7, 2024
018bc9e
Merge branch 'main' into chris/languages-3
chrisrichardson May 28, 2024
584d648
Merge branch 'main' into chris/languages-3
chrisrichardson May 30, 2024
7052541
fixes
chrisrichardson May 30, 2024
aacc715
Remove wrapper
chrisrichardson May 30, 2024
2134b36
Merge branch 'main' into schnellerhase/copy-chris/languages-3
schnellerhase Nov 26, 2025
9c2ad9d
Fix
schnellerhase Nov 26, 2025
2f3c3d5
[tmp] no mypy
schnellerhase Nov 26, 2025
a41f7d7
[tmp] no ruff
schnellerhase Nov 26, 2025
8b73bc3
Start with numba
schnellerhase Nov 26, 2025
218ce8d
drop cpp/
schnellerhase Nov 26, 2025
1c3e77d
ruff
schnellerhase Nov 26, 2025
6bacc68
Merge branch 'main' into schnellerhase/copy-chris/languages-3
schnellerhase Nov 26, 2025
a2079c3
add __all__
schnellerhase Nov 26, 2025
d651099
Test numba standalone
schnellerhase Nov 26, 2025
5b3c7ad
Get CI up
schnellerhase Nov 26, 2025
3363324
dependency numba
schnellerhase Nov 26, 2025
b98b945
Explicit
schnellerhase Nov 26, 2025
4724e1a
install numba module
schnellerhase Nov 26, 2025
1543c5e
parallel test execution
schnellerhase Nov 26, 2025
1420eb1
Fix complex
schnellerhase Nov 26, 2025
96eafee
Make optional
schnellerhase Nov 26, 2025
c4aeb22
Skip real only demos
schnellerhase Nov 26, 2025
bd1d933
Reactivate ruff
schnellerhase Nov 26, 2025
f85afda
Start on unit testing (almost passing)
schnellerhase Nov 27, 2025
655d92d
Commit ufl, not kernel
schnellerhase Nov 27, 2025
f7d308b
Add custom_data input
schnellerhase Nov 27, 2025
09a0576
Fix n_const computation
schnellerhase Nov 27, 2025
5b9253a
Add array creation by (full) scalar value
schnellerhase Nov 27, 2025
be5f6b5
n_coeff computation
schnellerhase Nov 27, 2025
e02c4af
Linear form + coefficient kernel
schnellerhase Nov 27, 2025
6233d19
Introduce number_coordinate_dofs to CommonExpressionIR
schnellerhase Nov 27, 2025
67c1522
Fix remaining shape computations
schnellerhase Nov 27, 2025
e0fc385
Fix no coordinate element case
schnellerhase Nov 27, 2025
9ed41a1
Parametrize over scalar type
schnellerhase Nov 27, 2025
0761b70
remove in pwd
schnellerhase Nov 27, 2025
b146cfa
Fix expressions template
schnellerhase Nov 27, 2025
8576e6b
Add expression tensor size computations
schnellerhase Nov 27, 2025
d3e1d35
Add expression test
schnellerhase Nov 27, 2025
b4f01de
Merge branch 'main' into schnellerhase/copy-chris/languages-3
schnellerhase Nov 27, 2025
dd0f073
Start fine tuning
schnellerhase Nov 27, 2025
e6333eb
modernize cmdline test
schnellerhase Nov 27, 2025
8730fa5
Merge with poisson test
schnellerhase Nov 27, 2025
1bfe332
Licensing headers
schnellerhase Nov 27, 2025
453d15f
Resolve path
schnellerhase Nov 27, 2025
70552cf
Fix files system mess
schnellerhase Nov 27, 2025
4e3b2bd
Language agnostic formatter name: c_format -> format
schnellerhase Nov 27, 2025
098328d
Language agnostic formatter name: c_impl -> impl
schnellerhase Nov 27, 2025
1fc4944
Language agnostic naming: c/numba_implementation -> implementation
schnellerhase Nov 27, 2025
2774e7c
Language agnostic naming: c/numbaFormatter -> Formatter
schnellerhase Nov 27, 2025
ec1f538
One more visualise fix
schnellerhase Nov 27, 2025
77280a8
Prepare alignment of expressions to C
schnellerhase Nov 27, 2025
77ac932
Integrals alost completely aligned
schnellerhase Nov 27, 2025
0bff1fc
Finalise cleanup
schnellerhase Nov 27, 2025
07a7e14
checked
schnellerhase Nov 27, 2025
f6be02b
Reactivate mypy
schnellerhase Nov 27, 2025
664bb1f
Fix mypy
schnellerhase Nov 27, 2025
161ad40
Fix: tensor size
schnellerhase Nov 28, 2025
ffdd3c6
Activate redundant check, drop uneccessary args
schnellerhase Nov 28, 2025
cab8da1
Add coordinate_element_hash to expression
schnellerhase Nov 28, 2025
48bd21f
Extend keys check and tidy of integrals
schnellerhase Nov 28, 2025
021a603
Tidy up test_demos and allow for further extensions
schnellerhase Nov 28, 2025
5249124
Add key checking to form
schnellerhase Nov 28, 2025
ca6a9e8
Try with Path
schnellerhase Nov 28, 2025
a45b464
Add choices
schnellerhase Nov 30, 2025
8546437
only import basix.ufl
schnellerhase Nov 30, 2025
d234b84
type hints
schnellerhase Nov 30, 2025
4dc78af
fixes
schnellerhase Nov 30, 2025
fa45531
no self
schnellerhase Nov 30, 2025
ff6498c
derived
schnellerhase Nov 30, 2025
3eeb26d
more
schnellerhase Nov 30, 2025
fcb5636
.
schnellerhase Nov 30, 2025
43c46d2
+1
schnellerhase Nov 30, 2025
2b2fab4
format
schnellerhase Nov 30, 2025
32e5f66
race condition on windows?
schnellerhase Dec 1, 2025
e196d86
No subprocess for FFCx call - should result in accurate coverage reports
schnellerhase Dec 6, 2025
c808bfd
ruff
schnellerhase Dec 6, 2025
880c448
Revert for demos, cwd more important than coverage
schnellerhase Dec 6, 2025
ef0e3a2
Update test/poisson.py
schnellerhase Dec 7, 2025
a6b4113
Drop user defined import
schnellerhase Dec 7, 2025
e676680
Move --language option to options.py
schnellerhase Dec 7, 2025
d8f317a
Fix: size of tensor, extend test to tensor valued expression
schnellerhase Dec 7, 2025
9ce56a1
off by one
schnellerhase Dec 7, 2025
80b53d6
int
schnellerhase Dec 7, 2025
c2c6da4
mypy
schnellerhase Dec 7, 2025
6b71383
Get type info from formatter - use _dtype_to_name logic
schnellerhase Dec 8, 2025
1251d16
mypy
schnellerhase Dec 8, 2025
7916b2d
ignore
schnellerhase Dec 8, 2025
8c5d7cd
format
schnellerhase Dec 8, 2025
611ebb9
Apply suggestions from code review
schnellerhase Dec 8, 2025
a5df6ab
Add docstrings and types to integrals
schnellerhase Dec 8, 2025
3faefbd
String/comment formatting
schnellerhase Dec 8, 2025
d662931
mypy
schnellerhase Dec 8, 2025
3ccd079
Raise bessel
schnellerhase Dec 8, 2025
eefeb6c
Introduce _format_comment_str
schnellerhase Dec 8, 2025
fe3b1de
docstring
schnellerhase Dec 8, 2025
5476288
Docstring
schnellerhase Dec 8, 2025
ec82f17
We do need bessel :)
schnellerhase Dec 8, 2025
2921d32
Remove API breaking chane with suffixes
schnellerhase Dec 8, 2025
d6a5996
Tidy
schnellerhase Dec 8, 2025
bdf95cc
last one?
schnellerhase Dec 8, 2025
6061a88
Apply suggestions from code review
schnellerhase Dec 8, 2025
28a47a6
Consistent factory_name with C
schnellerhase Dec 8, 2025
38edf48
Template of form was unallgined with C, rewrite
schnellerhase Dec 8, 2025
d46d49f
NULL -> None
schnellerhase Dec 8, 2025
af5078d
Add warning to file template
schnellerhase Dec 9, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion .github/workflows/pythonapp.yml
Original file line number Diff line number Diff line change
Expand Up @@ -122,8 +122,10 @@ jobs:
uses: ilammy/msvc-dev-cmd@v1

- name: Run FFCx demos
run: |
run: >
pytest demo/test_demos.py
-W error
# -n auto

- name: Build documentation
run: |
Expand Down
2 changes: 1 addition & 1 deletion demo/MassAction.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
"""Mass action demo."""

import basix
import basix.ufl
import ufl

P = 3
Expand Down
93 changes: 60 additions & 33 deletions demo/test_demos.py
Original file line number Diff line number Diff line change
@@ -1,60 +1,87 @@
"""Test demos."""

import os
import subprocess
import sys
from pathlib import Path

import pytest

demo_dir = os.path.dirname(os.path.realpath(__file__))
demo_dir = Path(__file__).parent

ufl_files = []
for file in os.listdir(demo_dir):
if file.endswith(".py") and not file == "test_demos.py":
ufl_files.append(file[:-3])
ufl_files = [
f
for f in demo_dir.iterdir()
if f.suffix == ".py" and not f.stem.endswith("_numba") and f != Path(__file__)
]

skip_complex = ["BiharmonicHHJ", "BiharmonicRegge", "StabilisedStokes"]


def skip_unsupported(test):
"""Dcecorate test case to skip unsupported cases."""

def check_skip(file, scalar_type):
"""Skip scalar_type file combinations not supported."""
if "complex" in scalar_type and file.stem in skip_complex:
pytest.skip(reason="Not implemented for complex types")
elif "Complex" in file.stem and scalar_type in ["float64", "float32"]:
pytest.skip(reason="Not implemented for real types")

return test(file, scalar_type)

return check_skip


@pytest.mark.parametrize("file", ufl_files)
@pytest.mark.parametrize("scalar_type", ["float64", "float32", "complex128", "complex64"])
def test_demo(file, scalar_type):
@skip_unsupported
def test_C(file, scalar_type):
"""Test a demo."""
if sys.platform.startswith("win32") and "complex" in scalar_type:
# Skip complex demos on win32
pytest.skip(reason="_Complex not supported on Windows")

if "complex" in scalar_type and file in [
"BiharmonicHHJ",
"BiharmonicRegge",
"StabilisedStokes",
]:
# Skip demos that are not implemented for complex scalars
pytest.skip(reason="Not implemented for complex types")
elif "Complex" in file and scalar_type in ["float64", "float32"]:
# Skip demos that are only implemented for complex scalars
pytest.skip(reason="Not implemented for real types")
subprocess.run(["ffcx", "--scalar_type", scalar_type, file], cwd=demo_dir, check=True)

if sys.platform.startswith("win32"):
opts = f"--scalar_type {scalar_type}"
extra_flags = "/std:c17"
assert os.system(f"cd {demo_dir} && ffcx {opts} {file}.py") == 0
assert (
os.system(
f'cd {demo_dir} && cl.exe /I "../ffcx/codegeneration" {extra_flags} /c {file}.c'
for compiler in ["cl.exe", "clang-cl.exe"]:
subprocess.run(
[
compiler,
"/I",
f"{demo_dir.parent / 'ffcx/codegeneration'}",
*extra_flags.split(" "),
"/c",
file.with_suffix(".c"),
],
cwd=demo_dir,
check=True,
)
) == 0
assert (
os.system(
f"cd {demo_dir} && "
f'clang-cl.exe /I "../ffcx/codegeneration" {extra_flags} /c {file}.c'
)
) == 0
else:
cc = os.environ.get("CC", "cc")
opts = f"--scalar_type {scalar_type}"
extra_flags = (
"-std=c17 -Wunused-variable -Werror -fPIC -Wno-error=implicit-function-declaration"
)
assert os.system(f"cd {demo_dir} && ffcx {opts} {file}.py") == 0
assert (
os.system(f"cd {demo_dir} && {cc} -I../ffcx/codegeneration {extra_flags} -c {file}.c")
== 0
subprocess.run(
[
cc,
f"-I{demo_dir.parent / 'ffcx/codegeneration'}",
*extra_flags.split(" "),
"-c",
file.with_suffix(".c"),
],
cwd=demo_dir,
check=True,
)


@pytest.mark.parametrize("file", ufl_files)
@pytest.mark.parametrize("scalar_type", ["float64", "float32", "complex128", "complex64"])
@skip_unsupported
def test_numba(file, scalar_type):
"""Test numba generation."""
opts = f"--language numba --scalar_type {scalar_type}"
subprocess.run(["ffcx", *opts.split(" "), file], cwd=demo_dir, check=True)
subprocess.run(["python", file], cwd=demo_dir, check=True)
4 changes: 4 additions & 0 deletions ffcx/codegeneration/C/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,5 @@
"""Generation of C code."""

from ffcx.codegeneration.C import expressions, file, form, integrals

__all__ = ["expressions", "file", "form", "integrals", "suffixes"]
12 changes: 7 additions & 5 deletions ffcx/codegeneration/C/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,13 @@
from __future__ import annotations

import logging
import string

import numpy as np

from ffcx.codegeneration.backend import FFCXBackend
from ffcx.codegeneration.C import expressions_template
from ffcx.codegeneration.C.c_implementation import CFormatter
from ffcx.codegeneration.C.implementation import Formatter
from ffcx.codegeneration.expression_generator import ExpressionGenerator
from ffcx.codegeneration.utils import dtype_to_c_type, dtype_to_scalar_dtype
from ffcx.ir.representation import ExpressionIR
Expand Down Expand Up @@ -43,8 +44,8 @@ def generator(ir: ExpressionIR, options):
d["factory_name"] = factory_name
parts = eg.generate()

CF = CFormatter(options["scalar_type"])
d["tabulate_expression"] = CF.c_format(parts)
CF = Formatter(options["scalar_type"])
d["tabulate_expression"] = CF.format(parts)

if len(ir.original_coefficient_positions) > 0:
d["original_coefficient_positions"] = f"original_coefficient_positions_{factory_name}"
Expand Down Expand Up @@ -106,9 +107,10 @@ def generator(ir: ExpressionIR, options):
d["coordinate_element_hash"] = f"UINT64_C({ir.expression.coordinate_element_hash})"

# Check that no keys are redundant or have been missed
from string import Formatter

fields = [fname for _, fname, _, _ in Formatter().parse(expressions_template.factory) if fname]
fields = [
fname for _, fname, _, _ in string.Formatter().parse(expressions_template.factory) if fname
]
assert set(fields) == set(d.keys()), "Mismatch between keys in template and in formatting dict"

# Format implementation code
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@
}


class CFormatter:
class Formatter:
"""C formatter."""

scalar_type: np.dtype
Expand Down Expand Up @@ -186,7 +186,7 @@ def _build_initializer_lists(self, values):

def format_statement_list(self, slist) -> str:
"""Format a statement list."""
return "".join(self.c_format(s) for s in slist.statements)
return "".join(self.format(s) for s in slist.statements)

def format_section(self, section) -> str:
"""Format a section."""
Expand All @@ -197,12 +197,12 @@ def format_section(self, section) -> str:
f"// Inputs: {', '.join(w.name for w in section.input)}\n"
f"// Outputs: {', '.join(w.name for w in section.output)}\n"
)
declarations = "".join(self.c_format(s) for s in section.declarations)
declarations = "".join(self.format(s) for s in section.declarations)

body = ""
if len(section.statements) > 0:
declarations += "{\n "
body = "".join(self.c_format(s) for s in section.statements)
body = "".join(self.format(s) for s in section.statements)
body = body.replace("\n", "\n ")
body = body[:-2] + "}\n"

Expand All @@ -218,7 +218,7 @@ def format_array_decl(self, arr) -> str:
dtype = arr.symbol.dtype
typename = self._dtype_to_name(dtype)

symbol = self.c_format(arr.symbol)
symbol = self.format(arr.symbol)
dims = "".join([f"[{i}]" for i in arr.sizes])
if arr.values is None:
assert arr.const is False
Expand All @@ -230,21 +230,21 @@ def format_array_decl(self, arr) -> str:

def format_array_access(self, arr) -> str:
"""Format an array access."""
name = self.c_format(arr.array)
indices = f"[{']['.join(self.c_format(i) for i in arr.indices)}]"
name = self.format(arr.array)
indices = f"[{']['.join(self.format(i) for i in arr.indices)}]"
return f"{name}{indices}"

def format_variable_decl(self, v) -> str:
"""Format a variable declaration."""
val = self.c_format(v.value)
symbol = self.c_format(v.symbol)
val = self.format(v.value)
symbol = self.format(v.symbol)
typename = self._dtype_to_name(v.symbol.dtype)
return f"{typename} {symbol} = {val};\n"

def format_nary_op(self, oper) -> str:
"""Format an n-ary operation."""
# Format children
args = [self.c_format(arg) for arg in oper.args]
args = [self.format(arg) for arg in oper.args]

# Apply parentheses
for i in range(len(args)):
Expand All @@ -257,8 +257,8 @@ def format_nary_op(self, oper) -> str:
def format_binary_op(self, oper) -> str:
"""Format a binary operation."""
# Format children
lhs = self.c_format(oper.lhs)
rhs = self.c_format(oper.rhs)
lhs = self.format(oper.lhs)
rhs = self.format(oper.rhs)

# Apply parentheses
if oper.lhs.precedence >= oper.precedence:
Expand All @@ -271,7 +271,7 @@ def format_binary_op(self, oper) -> str:

def format_unary_op(self, oper) -> str:
"""Format a unary operation."""
arg = self.c_format(oper.arg)
arg = self.format(oper.arg)
if oper.arg.precedence >= oper.precedence:
return f"{oper.op}({arg})"
return f"{oper.op}{arg}"
Expand All @@ -287,12 +287,12 @@ def format_literal_int(self, val) -> str:

def format_for_range(self, r) -> str:
"""Format a for loop over a range."""
begin = self.c_format(r.begin)
end = self.c_format(r.end)
index = self.c_format(r.index)
begin = self.format(r.begin)
end = self.format(r.end)
index = self.format(r.index)
output = f"for (int {index} = {begin}; {index} < {end}; ++{index})\n"
output += "{\n"
body = self.c_format(r.body)
body = self.format(r.body)
for line in body.split("\n"):
if len(line) > 0:
output += f" {line}\n"
Expand All @@ -301,20 +301,20 @@ def format_for_range(self, r) -> str:

def format_statement(self, s) -> str:
"""Format a statement."""
return self.c_format(s.expr)
return self.format(s.expr)

def format_assign(self, expr) -> str:
"""Format an assignment."""
rhs = self.c_format(expr.rhs)
lhs = self.c_format(expr.lhs)
rhs = self.format(expr.rhs)
lhs = self.format(expr.lhs)
return f"{lhs} {expr.op} {rhs};\n"

def format_conditional(self, s) -> str:
"""Format a conditional."""
# Format children
c = self.c_format(s.condition)
t = self.c_format(s.true)
f = self.c_format(s.false)
c = self.format(s.condition)
t = self.format(s.true)
f = self.format(s.false)

# Apply parentheses
if s.condition.precedence >= s.precedence:
Expand All @@ -333,7 +333,7 @@ def format_symbol(self, s) -> str:

def format_multi_index(self, mi) -> str:
"""Format a multi-index."""
return self.c_format(mi.global_index)
return self.format(mi.global_index)

def format_math_function(self, c) -> str:
"""Format a mathematical function."""
Expand All @@ -349,10 +349,10 @@ def format_math_function(self, c) -> str:

# Get a function from the table, if available, else just use bare name
func = dtype_math_table.get(c.function, c.function)
args = ", ".join(self.c_format(arg) for arg in c.args)
args = ", ".join(self.format(arg) for arg in c.args)
return f"{func}({args})"

c_impl = {
impl = {
"Section": format_section,
"StatementList": format_statement_list,
"Comment": format_comment,
Expand Down Expand Up @@ -387,10 +387,10 @@ def format_math_function(self, c) -> str:
"LT": format_binary_op,
}

def c_format(self, s) -> str:
def format(self, s) -> str:
"""Format as C."""
name = s.__class__.__name__
try:
return self.c_impl[name](self, s)
return self.impl[name](self, s)
except KeyError:
raise RuntimeError("Unknown statement: ", name)
Loading
Loading