diff --git a/.github/workflows/generate_website.yml b/.github/workflows/generate_website.yml index 2aa1432..115fb8e 100644 --- a/.github/workflows/generate_website.yml +++ b/.github/workflows/generate_website.yml @@ -4,6 +4,7 @@ on: push: branches: - main + pull_request: workflow_dispatch: permissions: @@ -32,10 +33,13 @@ jobs: - uses: julia-actions/julia-buildpkg@v1 - # This step sends model_keys and adtype_keys to GITHUB_OUTPUT + - uses: astral-sh/setup-uv@v5 + with: + python-version: "3.13" + - name: Setup keys id: keys - run: ./ad.sh setup + run: uv run ad.py setup run-models: runs-on: ubuntu-latest @@ -61,9 +65,13 @@ jobs: - uses: julia-actions/julia-buildpkg@v1 - - name: Run AD + - uses: astral-sh/setup-uv@v5 + with: + python-version: "3.13" + + - name: Run given model with all adtypes id: run - run: ./ad.sh run-model ${{ matrix.model }} + run: uv run ad.py run --model ${{ matrix.model }} env: ADTYPE_KEYS: ${{ needs.setup-keys.outputs.adtype_keys }} @@ -77,6 +85,7 @@ jobs: collect-results: runs-on: ubuntu-latest + if: github.event_name != 'pull_request' needs: run-models steps: @@ -86,9 +95,7 @@ jobs: with: python-version: "3.13" - - run: | - uv python install - uv run collate.py + - run: uv run ad.py html env: RESULTS_JSON: ${{ needs.run-models.outputs.json }} diff --git a/ad.py b/ad.py new file mode 100644 index 0000000..266a491 --- /dev/null +++ b/ad.py @@ -0,0 +1,309 @@ +""" +ad.py +----- + +Top-level Python script which orchestrates the Julia AD tests. + +Usage: + + python ad.py setup + python ad.py run --model + python ad.py html +""" + +import json +import os +import subprocess as sp +import tomllib +import argparse +from pathlib import Path +from warnings import warn + +JULIA_COMMAND = ["julia", "--color=yes", "--project=.", "main.jl"] + +def run_and_capture(command): + """Run a command and capture its output.""" + result = sp.run(command, text=True, check=True, stdout=sp.PIPE) + return result.stdout.strip() + +def append_to_github_output(key, value): + """Append a key-value pair to the file specified by $GITHUB_OUTPUT.""" + pair = f"{key}={json.dumps(value)}" + try: + fname = os.environ["GITHUB_OUTPUT"] + with open(fname, "a") as f: + print(pair, file=f) + except KeyError: + print(f"GITHUB_OUTPUT not set") + print(pair) + +def setup(_args): + models = run_and_capture([*JULIA_COMMAND, "--list-model-keys"]).splitlines() + adtypes = run_and_capture([*JULIA_COMMAND, "--list-adtype-keys"]).splitlines() + append_to_github_output("model_keys", models) + append_to_github_output("adtype_keys", adtypes) + # TODO: Save the Manifest.toml file or at least a mapping of packages -> + # versions, see #9 + +def run_ad(args): + model_key = args.model + + # Get adtypes + try: + adtypes = json.loads(os.environ["ADTYPE_KEYS"]) + except KeyError: + warn("ADTYPE_KEYS environment variable not set; running Julia to get adtypes") + adtypes = run_and_capture([*JULIA_COMMAND, "--list-adtype-keys"]).splitlines() + + results = {} + + # Run tests + for adtype in adtypes: + print(f"Running {model_key} with {adtype}...") + try: + output = run_and_capture([*JULIA_COMMAND, "--run", model_key, adtype]) + result = output.splitlines()[-1] + except sp.CalledProcessError as e: + result = "error" + + print(f" ... {model_key} with {adtype} ==> {result}") + results[adtype] = result + + print(results) + + # Save results + append_to_github_output("results", results) + + +def html(_args): + ## Here you can register known errors that have been reported on GitHub / + ## have otherwise been documented. They will be turned into links in the table. + + ENZYME_RVS_ONE_PARAM = "https://github.com/EnzymeAD/Enzyme.jl/issues/2337" + ENZYME_FWD_BLAS = "https://github.com/EnzymeAD/Enzyme.jl/issues/1995" + KNOWN_ERRORS = { + ("assume_beta", "EnzymeReverse"): ENZYME_RVS_ONE_PARAM, + ("assume_dirichlet", "EnzymeReverse"): ENZYME_RVS_ONE_PARAM, + ("assume_lkjcholu", "EnzymeReverse"): ENZYME_RVS_ONE_PARAM, + ("assume_normal", "EnzymeReverse"): ENZYME_RVS_ONE_PARAM, + ("assume_wishart", "EnzymeReverse"): ENZYME_RVS_ONE_PARAM, + ("assume_mvnormal", "EnzymeForward"): ENZYME_FWD_BLAS, + ("assume_wishart", "EnzymeForward"): ENZYME_FWD_BLAS, + } + + results = os.environ.get("RESULTS_JSON", None) + + if results is None: + print("RESULTS_JSON not set") + exit(1) + else: + print("-------- $RESULTS_JSON --------") + print(results) + print("------------- END -------------") + # results is a list of dicts that looks something like this. + # [ + # {"model_name": "model1", + # "results": { + # "AD1": "result1", + # "AD2": "result2" + # } + # }, + # {"model_name": "model2", + # "results": { + # "AD1": "result3", + # "AD2": "result4" + # } + # } + # ] + # We do some processing to turn it into a dict of dicts + results = json.loads(results) + results = {entry["model_name"]: entry["results"] for entry in results} + + # You can also process this with pandas. I don't do that here because + # (1) extra dependency + # (2) df.to_html() doesn't have enough customisation for our purposes. + # + # import pandas as pd + # results_flattened = [ + # {"model_name": entry["model_name"], **entry["results"]} + # for entry in json.loads(results) + # ] + # df = pd.DataFrame.from_records(results_flattened) + + adtypes = sorted(list(results.values())[0].keys()) + models = sorted(results.keys()) + + # Create the directory if it doesn't exist + os.makedirs("html", exist_ok=True) + with open("html/index.html", "w") as f: + f.write( +""" + +Turing AD tests + + +
+

Turing AD tests

+ +

Turing.jl documentation | Turing.jl GitHub | Source code for these tests

+ +

This page is intended as a brief overview of how different AD backends +perform on a variety of Turing.jl models. +Note that the inclusion of any AD backend here does not imply an endorsement +from the Turing team; this table is purely for information. +

+ + + +

Results

+""") + + # Table header + f.write('') + f.write("") + f.write("") + for adtype in adtypes: + f.write(f"") + f.write("") + # Table body + for model_name in models: + ad_results = results[model_name] + f.write("\n") + f.write(f"") + for adtype in adtypes: + ad_result = ad_results[adtype] + try: + float(ad_result) + f.write(f'') + except ValueError: + # Not a float, embed the class into the html + error_url = KNOWN_ERRORS.get((model_name, adtype), None) + span = f'{ad_result}' + if error_url is not None: + span = f'(?) {span}' + f.write(f'') + f.write("") + f.write("\n
Model name \\ AD type{adtype}
{model_name}{ad_result}{span}
") + + with open("html/main.css", "w") as f: + f.write( +""" +@import url('https://fonts.googleapis.com/css2?family=Fira+Code:wght@300..700&family=Fira+Sans:ital,wght@0,100;0,200;0,300;0,400;0,500;0,600;0,700;0,800;0,900;1,100;1,200;1,300;1,400;1,500;1,600;1,700;1,800;1,900&display=swap'); +html { + font-family: "Fira Sans", sans-serif; + box-sizing: border-box; + font-size: 16px; + line-height: 1.6; + background-color: #f1f2e3; +} +*, *:before, *:after { + box-sizing: inherit; +} + +body { + display: flex; + align-items: center; + margin: 0px 0px 50px 0px; +} + +main { + margin: auto; + max-width: 1250px; +} + +table#results { + text-align: right; + border: 1px solid black; + border-collapse: collapse; +} + +td, th { + border: 1px solid black; + padding: 0px 10px; +} + +th { + background-color: #ececec; + text-align: right; +} + +td { + font-family: "Fira Code", monospace; +} + +tr > td:first-child { + font-family: "Fira Sans", sans-serif; + font-weight: 700; + background-color: #ececec; +} + +tr > th:first-child { + font-family: "Fira Sans", sans-serif; + font-weight: 700; + background-color: #d1d1d1; +} + +span.err, span.error { + color: #ff0000; +} + +span.incorrect, span.wrong { + color: #ff0000; + background-color: #ffcccc; +} + +a.issue { + color: #880000; + text-decoration: none; +} + +a.issue:hover { + background-color: #ffcccc; + transition: background-color 0.3s ease; +} + +a.issue:visited { + color: #880000; +} +""") + + +def parse_arguments(): + parser = argparse.ArgumentParser(description="Run AD tests") + subparsers = parser.add_subparsers(required=True) + + # Setup + parser_setup = subparsers.add_parser("setup", help="Setup by saving model keys, adtype keys, and Manifest") + parser_setup.set_defaults(func=setup) + + # Run a given model with all adtypes + parser_run = subparsers.add_parser("run", help="Run a given model with all adtypes") + parser_run.add_argument( + "--model", type=str, help="Key of the model to run" + ) + parser_run.set_defaults(func=run_ad) + + # Generate HTML page + parser_html = subparsers.add_parser("html", help="Generate HTML page") + parser_html.set_defaults(func=html) + + return parser.parse_args() + + +if __name__ == "__main__": + args = parse_arguments() + args.func(args) diff --git a/ad.sh b/ad.sh deleted file mode 100755 index 412c55e..0000000 --- a/ad.sh +++ /dev/null @@ -1,85 +0,0 @@ -JULIA_COMMAND=("julia" "--color=yes" "--project=." "main.jl") - -if [ ! -f "main.jl" ]; then - echo "Could not find main.jl in the current directory. Please run this script from the directory containing main.jl." - exit 1 -fi - -if [ -z "$GITHUB_OUTPUT" ]; then - GITHUB_OUTPUT=$(mktemp) - echo "Not running on GitHub; using temporary file (${GITHUB_OUTPUT}) for output" -fi - -usage () { - echo "Usage: $0 [setup|run ]" - echo " setup: store available model keys and AD types in GITHUB_OUTPUT" - echo " run-model : Run the given model with all AD types" - exit 1 -} - -show_output () { - echo "----- \$GITHUB_OUTPUT CONTENTS BEGIN -----" - cat $GITHUB_OUTPUT - echo "----- \$GITHUB_OUTPUT CONTENTS END -----" -} - -setup () { - echo "Getting model keys..." - readarray -t MODELS < <(${JULIA_COMMAND[@]} --list-model-keys) - MODELS_JSON=$(jq -c -n '$ARGS.positional' --args ${MODELS[@]}) - echo "model_keys=${MODELS_JSON}" >> "${GITHUB_OUTPUT}" - - echo "Getting adtype keys..." - readarray -t ADTYPES < <(${JULIA_COMMAND[@]} --list-adtype-keys) - ADTYPE_JSON=$(jq -c -n '$ARGS.positional' --args ${ADTYPES[@]}) - echo "adtype_keys=${ADTYPE_JSON}" >> "${GITHUB_OUTPUT}" - - show_output -} - -# check if script is called with setup option -if [ "$1" == "setup" ]; then - setup - exit 0 -elif [ "$1" == "run-model" ]; then - if [ -z "$2" ]; then - usage - fi - MODEL_KEY=$2 - - if [ -z "$ADTYPE_KEYS" ]; then - echo "ADTYPE_KEYS is not set" - exit 1 - fi - readarray -t ADTYPES < <(echo $ADTYPE_KEYS | jq -r '.[]') - - declare -A RESULTS - - # run the model with the specified key - for ADTYPE in "${ADTYPES[@]}"; do - echo "Running ${MODEL_KEY} with ${ADTYPE}... " - OUTPUT=$(timeout 5m ${JULIA_COMMAND[@]} --run "${MODEL_KEY}" "${ADTYPE}") - if [ $? -eq 0 ]; then - RESULT=$(echo "${OUTPUT}" | tail -n 1) - else - RESULT="error" - fi - echo " ... ${MODEL_KEY} with ${ADTYPE} ==> ${RESULT}" - RESULTS["${ADTYPE}"]="${RESULT}" - done - - # Convert the associative array to JSON representation - RESULTS_JSON=$(for i in "${!RESULTS[@]}" - do - echo "$i" - echo "${RESULTS[$i]}" - done | - jq -c -n -R 'reduce inputs as $i ({}; . + { ($i): (input) })' - ) - - echo "results=${RESULTS_JSON}" >> "${GITHUB_OUTPUT}" - - show_output -else - usage -fi diff --git a/collate.py b/collate.py deleted file mode 100644 index 25680b4..0000000 --- a/collate.py +++ /dev/null @@ -1,207 +0,0 @@ -import json -import os - -## Here you can register known errors that have been reported on GitHub / -## have otherwise been documented. They will be turned into links in the table. - -ENZYME_RVS_ONE_PARAM = "https://github.com/EnzymeAD/Enzyme.jl/issues/2337" -ENZYME_FWD_BLAS = "https://github.com/EnzymeAD/Enzyme.jl/issues/1995" -KNOWN_ERRORS = { - ("assume_beta", "EnzymeReverse"): ENZYME_RVS_ONE_PARAM, - ("assume_dirichlet", "EnzymeReverse"): ENZYME_RVS_ONE_PARAM, - ("assume_lkjcholu", "EnzymeReverse"): ENZYME_RVS_ONE_PARAM, - ("assume_normal", "EnzymeReverse"): ENZYME_RVS_ONE_PARAM, - ("assume_wishart", "EnzymeReverse"): ENZYME_RVS_ONE_PARAM, - ("assume_mvnormal", "EnzymeForward"): ENZYME_FWD_BLAS, - ("assume_wishart", "EnzymeForward"): ENZYME_FWD_BLAS, -} - -results = os.environ.get("RESULTS_JSON", None) - -if results is None: - print("RESULTS_JSON not set") - exit(1) -else: - print("-------- $RESULTS_JSON --------") - print(results) - print("------------- END -------------") - # results is a list of dicts that looks something like this. - # [ - # {"model_name": "model1", - # "results": { - # "AD1": "result1", - # "AD2": "result2" - # } - # }, - # {"model_name": "model2", - # "results": { - # "AD1": "result3", - # "AD2": "result4" - # } - # } - # ] - # We do some processing to turn it into a dict of dicts - results = json.loads(results) - results = {entry["model_name"]: entry["results"] for entry in results} - -# You can also process this with pandas. I don't do that here because -# (1) extra dependency -# (2) df.to_html() doesn't have enough customisation for our purposes. -# -# import pandas as pd -# results_flattened = [ -# {"model_name": entry["model_name"], **entry["results"]} -# for entry in json.loads(results) -# ] -# df = pd.DataFrame.from_records(results_flattened) - -adtypes = sorted(list(results.values())[0].keys()) -models = sorted(results.keys()) - -# Create the directory if it doesn't exist -os.makedirs("html", exist_ok=True) -with open("html/index.html", "w") as f: - f.write( -""" - -Turing AD tests - - -
-

Turing AD tests

- -

Turing.jl documentation | Turing.jl GitHub | Source code for these tests

- -

This page is intended as a brief overview of how different AD backends -perform on a variety of Turing.jl models. -Note that the inclusion of any AD backend here does not imply an endorsement -from the Turing team; this table is purely for information. -

- -
    -
  • The definitions of the models and AD types below can be found on GitHub.
  • -
  • Numbers indicate the time taken to calculate the gradient of the log -density of the model using the specified AD type, divided by the time taken to -calculate the log density itself (in AD speak, the primal). Basically: -smaller means faster.
  • -
  • 'wrong' means that AD ran but the result was not -correct. If this happens you should be very wary! Note that this is done by -comparing against the result obtained using ForwardDiff, i.e., ForwardDiff is -by definition always 'correct'.
  • -
  • 'error' means that AD didn't run.
  • -
  • Some of the 'wrong' or 'error' entries have question marks next to them. -These will link to a GitHub issue or other page that describes the problem. -
- -

Results

-""") - - # Table header - f.write('') - f.write("") - f.write("") - for adtype in adtypes: - f.write(f"") - f.write("") - # Table body - for model_name in models: - ad_results = results[model_name] - f.write("\n") - f.write(f"") - for adtype in adtypes: - ad_result = ad_results[adtype] - try: - float(ad_result) - f.write(f'') - except ValueError: - # Not a float, embed the class into the html - error_url = KNOWN_ERRORS.get((model_name, adtype), None) - span = f'{ad_result}' - if error_url is not None: - span = f'(?) {span}' - f.write(f'') - f.write("") - f.write("\n
Model name \ AD type{adtype}
{model_name}{ad_result}{span}
") - -with open("html/main.css", "w") as f: - f.write( -""" -@import url('https://fonts.googleapis.com/css2?family=Fira+Code:wght@300..700&family=Fira+Sans:ital,wght@0,100;0,200;0,300;0,400;0,500;0,600;0,700;0,800;0,900;1,100;1,200;1,300;1,400;1,500;1,600;1,700;1,800;1,900&display=swap'); -html { - font-family: "Fira Sans", sans-serif; - box-sizing: border-box; - font-size: 16px; - line-height: 1.6; - background-color: #f1f2e3; -} -*, *:before, *:after { - box-sizing: inherit; -} - -body { - display: flex; - align-items: center; - margin: 0px 0px 50px 0px; -} - -main { - margin: auto; - max-width: 1250px; -} - -table#results { - text-align: right; - border: 1px solid black; - border-collapse: collapse; -} - -td, th { - border: 1px solid black; - padding: 0px 10px; -} - -th { - background-color: #ececec; - text-align: right; -} - -td { - font-family: "Fira Code", monospace; -} - -tr > td:first-child { - font-family: "Fira Sans", sans-serif; - font-weight: 700; - background-color: #ececec; -} - -tr > th:first-child { - font-family: "Fira Sans", sans-serif; - font-weight: 700; - background-color: #d1d1d1; -} - -span.err, span.error { - color: #ff0000; -} - -span.incorrect, span.wrong { - color: #ff0000; - background-color: #ffcccc; -} - -a.issue { - color: #880000; - text-decoration: none; -} - -a.issue:hover { - background-color: #ffcccc; - transition: background-color 0.3s ease; -} - -a.issue:visited { - color: #880000; -} -""")