Commit a06c916

Merge branch 'main' into gh/swolchok/517/head
2 parents 995f6fa + a624083

File tree: 4 files changed, +32 -39 lines

.github/workflows/pull.yml

Lines changed: 27 additions & 26 deletions
@@ -632,32 +632,33 @@ jobs:
         # run eval_llama wikitext task
         PYTHON_EXECUTABLE=python bash .ci/scripts/test_eval_llama_wikitext.sh
 
-  test-eval_llama-mmlu-linux:
-    name: test-eval_llama-mmlu-linux
-    uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
-    permissions:
-      id-token: write
-      contents: read
-    strategy:
-      fail-fast: false
-    with:
-      runner: linux.24xlarge
-      docker-image: ci-image:executorch-ubuntu-22.04-clang12
-      submodules: 'recursive'
-      ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
-      timeout: 90
-      script: |
-        # The generic Linux job chooses to use base env, not the one setup by the image
-        CONDA_ENV=$(conda env list --json | jq -r ".envs | .[-1]")
-        conda activate "${CONDA_ENV}"
-
-        PYTHON_EXECUTABLE=python bash .ci/scripts/setup-linux.sh --build-tool "cmake"
-
-        # install llama requirements
-        bash examples/models/llama/install_requirements.sh
-
-        # run eval_llama mmlu task
-        PYTHON_EXECUTABLE=python bash .ci/scripts/test_eval_llama_mmlu.sh
+  # TODO(larryliu0820): Fix this issue before reenabling it: https://gist.github.com/larryliu0820/7377ecd0d79dbc06076cec8d9f2b85d2
+  # test-eval_llama-mmlu-linux:
+  #   name: test-eval_llama-mmlu-linux
+  #   uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
+  #   permissions:
+  #     id-token: write
+  #     contents: read
+  #   strategy:
+  #     fail-fast: false
+  #   with:
+  #     runner: linux.24xlarge
+  #     docker-image: ci-image:executorch-ubuntu-22.04-clang12
+  #     submodules: 'recursive'
+  #     ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
+  #     timeout: 90
+  #     script: |
+  #       # The generic Linux job chooses to use base env, not the one setup by the image
+  #       CONDA_ENV=$(conda env list --json | jq -r ".envs | .[-1]")
+  #       conda activate "${CONDA_ENV}"
+
+  #       PYTHON_EXECUTABLE=python bash .ci/scripts/setup-linux.sh --build-tool "cmake"
+
+  #       # install llama requirements
+  #       bash examples/models/llama/install_requirements.sh
+
+  #       # run eval_llama mmlu task
+  #       PYTHON_EXECUTABLE=python bash .ci/scripts/test_eval_llama_mmlu.sh
 
   test-llama_runner_eager-linux:
     name: test-llama_runner_eager-linux

exir/backend/test/test_backends.py

Lines changed: 1 addition & 1 deletion
@@ -1033,7 +1033,7 @@ def false_fn(x, y):
 
         def f(x, y):
             x = x + y
-            x = torch.ops.higher_order.cond(x[0][0] == 1, true_fn, false_fn, [x, y])
+            x = torch.cond(x[0][0] == 1, true_fn, false_fn, [x, y])
            x = x - y
             return x
 
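Aside: torch.cond is the public wrapper around the torch.ops.higher_order.cond higher-order operator, so this rename changes the call site, not the semantics. A minimal standalone sketch of the new entry point (illustrative code, not from this repo; assumes PyTorch 2.4+, where torch.cond is exposed):

    import torch

    def true_fn(x: torch.Tensor) -> torch.Tensor:
        return x + x

    def false_fn(x: torch.Tensor) -> torch.Tensor:
        return x * x

    x = torch.randn(5)
    # The predicate may be a Python bool or a boolean scalar tensor; both
    # branches must take the same operands and return matching shapes/dtypes.
    out = torch.cond(x.sum() > 0, true_fn, false_fn, [x])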

exir/tests/control_flow_models.py

Lines changed: 3 additions & 9 deletions
@@ -20,9 +20,7 @@ def true_branch(x):
         def false_branch(x):
             return x * x
 
-        return torch.ops.higher_order.cond(
-            inp.sum() > 4, true_branch, false_branch, [inp]
-        )
+        return torch.cond(inp.sum() > 4, true_branch, false_branch, [inp])
 
     def get_random_inputs(self):
         return (torch.rand(5),)
@@ -39,9 +37,7 @@ def true_branch(x):
         def false_branch(x):
             return x * x * x
 
-        return torch.ops.higher_order.cond(
-            inp.sum() > 4, true_branch, false_branch, [inp]
-        )
+        return torch.cond(inp.sum() > 4, true_branch, false_branch, [inp])
 
     def get_upper_bound_inputs(self):
         return (torch.rand(8),)
@@ -72,9 +68,7 @@ def true_branch(x):
         def false_branch(x):
             return x * 2
 
-        return torch.ops.higher_order.cond(
-            inp.sum() > 4, true_branch, false_branch, [inp]
-        )
+        return torch.cond(inp.sum() > 4, true_branch, false_branch, [inp])
 
     def get_random_inputs(self):
         return (torch.eye(5) * 2,)
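Since these are export test models, here is a standalone sketch of how a model like the ones above round-trips through torch.export with the new call (illustrative only; CondToy is a made-up name, and PyTorch 2.4+ with torch.export is assumed):

    import torch

    class CondToy(torch.nn.Module):
        def forward(self, inp: torch.Tensor) -> torch.Tensor:
            def true_branch(x):
                return x + x

            def false_branch(x):
                return x * x

            return torch.cond(inp.sum() > 4, true_branch, false_branch, [inp])

    # Export traces both branches, so the conditional is preserved in the
    # graph rather than specialized on the example input.
    ep = torch.export.export(CondToy(), (torch.rand(5),))
    print(ep.graph)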

exir/tests/test_passes.py

Lines changed: 1 addition & 3 deletions
@@ -1463,9 +1463,7 @@ def forward(self, pred, x):
                 out = torch.nn.functional.linear(
                     x, self.w.to(torch.float16).to(torch.float32)
                 )
-                return torch.ops.higher_order.cond(
-                    pred, self.true_fn, self.false_fn, [out]
-                )
+                return torch.cond(pred, self.true_fn, self.false_fn, [out])
 
         mod = Module()
         x = torch.randn([3, 3])
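Passes that match on the operator itself should be unaffected by this rename: under torch.export, a torch.cond call still lowers to a torch.ops.higher_order.cond node in the graph. A hedged standalone sketch of verifying that (illustrative only; assumes PyTorch 2.4+):

    import torch

    class M(torch.nn.Module):
        def forward(self, pred: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
            return torch.cond(pred, torch.sin, torch.cos, [x])

    ep = torch.export.export(M(), (torch.tensor(True), torch.randn(3)))
    # The public torch.cond call appears as the higher-order cond op, so
    # graph passes keyed on torch.ops.higher_order.cond keep matching.
    assert any(
        node.op == "call_function" and node.target is torch.ops.higher_order.cond
        for node in ep.graph.nodes
    )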
