Skip to content

Commit edfac24

Browse files
authored
Merge pull request #1318 from Trusted-AI/development_goturn
Add estimator for object tracker GOTURN in PyTorch
2 parents 7859037 + 6b2b910 commit edfac24

File tree

11 files changed

+1141
-1
lines changed

11 files changed

+1141
-1
lines changed

.github/actions/goturn/Dockerfile

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
# Get base from a pytorch image
2+
FROM pytorch/pytorch:1.6.0-cuda10.1-cudnn7-runtime
3+
4+
# Set to install things in non-interactive mode
5+
ENV DEBIAN_FRONTEND noninteractive
6+
7+
# Install system wide softwares
8+
RUN apt-get update \
9+
&& apt-get install -y \
10+
libgl1-mesa-glx \
11+
libx11-xcb1 \
12+
git \
13+
gcc \
14+
mono-mcs \
15+
libavcodec-extra \
16+
ffmpeg \
17+
curl \
18+
libsndfile-dev \
19+
libsndfile1 \
20+
&& apt-get install -y libsm6 libxext6 \
21+
&& apt-get install -y libxrender-dev \
22+
&& apt-get clean all \
23+
&& rm -r /var/lib/apt/lists/*
24+
25+
RUN /opt/conda/bin/conda install --yes \
26+
astropy \
27+
matplotlib \
28+
pandas \
29+
scikit-learn \
30+
scikit-image
31+
32+
# Install necessary libraries for goturn
33+
RUN pip install torch==1.4
34+
RUN pip install tensorflow==2.1.4
35+
RUN pip install torchaudio==0.5.0
36+
RUN pip install pytest
37+
RUN pip install numba
38+
RUN pip install scikit-learn==0.20
39+
RUN pip install pytest-cov
40+
RUN pip install gdown
41+
42+
RUN git clone https://github.com/nrupatunga/goturn-pytorch.git /tmp/goturn-pytorch
43+
RUN cd /tmp/goturn-pytorch && pip install -r requirements.txt
44+
45+
RUN pip install numpy==1.20.3
46+
47+
ENV PYTHONPATH "${PYTHONPATH}:/tmp/goturn-pytorch/src"
48+
ENV PYTHONPATH "${PYTHONPATH}:/tmp/goturn-pytorch/src/scripts"
49+
50+
RUN mkdir /tmp/goturn-pytorch/src/goturn/models/checkpoints
51+
RUN cd /tmp/goturn-pytorch/src/goturn/models/checkpoints && gdown https://drive.google.com/uc?id=1GouImhqpcoDtV_eLrD2wra-qr3vkAMY4

.github/actions/goturn/action.yml

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
name: 'Test GOTURN'
2+
description: 'Run tests for GOTURN'
3+
runs:
4+
using: 'composite'
5+
steps:
6+
- run: $GITHUB_ACTION_PATH/run.sh
7+
shell: bash

.github/actions/goturn/run.sh

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
#!/bin/bash
2+
3+
exit_code=0
4+
5+
pytest --cov-report=xml --cov=art --cov-append -q -vv tests/estimators/object_tracking/test_pytorch_goturn.py --framework=pytorch --durations=0
6+
if [[ $? -ne 0 ]]; then exit_code=1; echo "Failed estimators/object_tracking/test_pytorch_goturn tests"; fi
7+
8+
exit ${exit_code}

.github/workflows/ci-goturn.yml

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
name: CI PyTorchGoturn
2+
on:
3+
# Run on manual trigger
4+
workflow_dispatch:
5+
6+
# Run on pull requests
7+
pull_request:
8+
paths-ignore:
9+
- '*.md'
10+
11+
# Run when pushing to main or dev branches
12+
push:
13+
branches:
14+
- main
15+
- dev*
16+
17+
# Run scheduled CI flow daily
18+
schedule:
19+
- cron: '0 8 * * 0'
20+
21+
jobs:
22+
test_pytorch_goturn:
23+
name: PyTorchGoturn
24+
runs-on: ubuntu-latest
25+
container: adversarialrobustnesstoolbox/art_testing_envs:goturn
26+
steps:
27+
- name: Checkout Repo
28+
uses: actions/[email protected]
29+
- name: Run Test Action
30+
uses: ./.github/actions/goturn
31+
- name: Upload coverage to Codecov
32+
uses: codecov/[email protected]

art/attacks/evasion/fast_gradient.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -472,7 +472,11 @@ def _apply_perturbation(
472472
batch = batch + perturbation_step
473473
if self.estimator.clip_values is not None:
474474
clip_min, clip_max = self.estimator.clip_values
475-
batch = np.clip(batch, clip_min, clip_max)
475+
if batch.dtype == np.object:
476+
for i_obj in range(batch.shape[0]):
477+
batch[i_obj] = np.clip(batch[i_obj], clip_min, clip_max)
478+
else:
479+
batch = np.clip(batch, clip_min, clip_max)
476480

477481
return batch
478482

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
"""
2+
Module containing estimators for object tracking.
3+
"""
4+
from art.estimators.object_tracking.pytorch_goturn import PyTorchGoturn
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
# MIT License
2+
#
3+
# Copyright (C) The Adversarial Robustness Toolbox (ART) Authors 2021
4+
#
5+
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
6+
# documentation files (the "Software"), to deal in the Software without restriction, including without limitation the
7+
# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit
8+
# persons to whom the Software is furnished to do so, subject to the following conditions:
9+
#
10+
# The above copyright notice and this permission notice shall be included in all copies or substantial portions of the
11+
# Software.
12+
#
13+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE
14+
# WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
15+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
16+
# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
17+
# SOFTWARE.
18+
"""
19+
This module implements mixin abstract base class for all object trackers in ART.
20+
"""
21+
22+
from abc import ABC, abstractmethod
23+
24+
from art.estimators.estimator import BaseEstimator
25+
from art.estimators.classification.classifier import LossGradientsMixin
26+
27+
28+
class ObjectTrackerMixin(ABC):
29+
"""
30+
Mix-in Base class for ART object trackers.
31+
"""
32+
33+
@property
34+
@abstractmethod
35+
def native_label_is_pytorch_format(self) -> bool:
36+
"""
37+
Are the native labels in PyTorch format [x1, y1, x2, y2]?
38+
"""
39+
raise NotImplementedError
40+
41+
42+
class ObjectTracker(ObjectTrackerMixin, LossGradientsMixin, BaseEstimator, ABC):
43+
"""
44+
Typing variable definition.
45+
"""
46+
47+
pass

0 commit comments

Comments
 (0)