Skip to content

Commit 59d489f

Browse files
authored
Add tests
1 parent 77fcd3f commit 59d489f

File tree

3 files changed

+53
-1
lines changed

3 files changed

+53
-1
lines changed

.github/workflows/main.yml

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
name: CI
2+
3+
on:
4+
push:
5+
branches: [ master ]
6+
pull_request:
7+
8+
jobs:
9+
run_linux:
10+
name: Test on Linux
11+
runs-on: ubuntu-18.04
12+
13+
steps:
14+
- uses: actions/checkout@v2
15+
16+
- name: Install
17+
run: |
18+
sudo apt-get install -y python3-setuptools
19+
python3 -m pip install --upgrade pip
20+
python3 -m pip install -r requirements.txt
21+
sudo python3 setup.py develop
22+
23+
- name: Download data and models
24+
run: |
25+
bonito download --models
26+
mkdir reads
27+
sudo apt-get install -y megatools
28+
megadl 'https://mega.nz/#!dccAET7R!-zq6ECPCzaN5jLjh8ASKRJdcjt-1325pEAtuIcmBGsg' --no-progress
29+
unzip reads.zip -d reads
30+
31+
- name: Test dna_r9.4.1@v2
32+
run: |
33+
bonito basecaller dna_r9.4.1@v2 --device=cpu reads > ref.fasta
34+
bonito basecaller dna_r9.4.1@v2 --use_openvino --device=cpu reads > out.fasta
35+
cat ref.fasta
36+
cat out.fasta
37+
cmp ref.fasta out.fasta
38+
39+
- name: Test dna_r10.3@v3
40+
run: |
41+
bonito basecaller dna_r10.3@v3 --device=cpu reads > ref.fasta
42+
bonito basecaller dna_r10.3@v3 --use_openvino --device=cpu reads > out.fasta
43+
cat ref.fasta
44+
cat out.fasta
45+
cmp ref.fasta out.fasta

bonito/crf/basecall.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,12 @@ def transfer(x):
6060
"""
6161
Device to host transfer using pinned memory.
6262
"""
63+
if not torch.cuda.is_available():
64+
return {
65+
k: torch.empty(v.shape, pin_memory=False, dtype=v.dtype).copy_(v).numpy()
66+
for k, v in x.items()
67+
}
68+
6369
torch.cuda.synchronize()
6470
with torch.cuda.stream(torch.cuda.Stream()):
6571
return {

bonito/ctc/basecall.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,8 @@ def compute_scores(model, batch):
3737
"""
3838
with torch.no_grad():
3939
device = next(model.parameters()).device
40-
chunks = batch.to(torch.half).to(device)
40+
chunks = batch.to(torch.half) if half_supported() else batch
41+
chunks = chunks.to(device)
4142
probs = permute(model(chunks), 'TNC', 'NTC')
4243
return probs.cpu().to(torch.float32)
4344

0 commit comments

Comments
 (0)