Skip to content

Commit e39b3a1

Browse files
authored
Merge pull request doccano#287 from CatalystCode/bugfix/seq2seq_label_download
Bugfix/seq2seq label download
2 parents 7e2ba5d + 286ea9e commit e39b3a1

File tree

5 files changed

+44
-11
lines changed

5 files changed

+44
-11
lines changed

app/api/managers.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
from collections import Counter
2+
3+
from django.db.models import Manager, Count
4+
5+
6+
class AnnotationManager(Manager):
    """Manager that adds per-project label and user annotation statistics."""

    def get_label_per_data(self, project):
        """Count a project's annotations per label text and per username.

        Args:
            project: Object exposing a ``documents`` related manager; only
                annotations whose document belongs to it are counted.

        Returns:
            A ``(label_count, user_count)`` tuple of ``Counter`` objects,
            keyed by label text and by username respectively.
        """
        label_count = Counter()
        user_count = Counter()

        # ``project.documents.all()`` is already a queryset and can be handed
        # to ``__in`` directly; cloning it again with ``.all()`` adds nothing.
        annotations = self.filter(document_id__in=project.documents.all())

        # ``values()`` before ``annotate()`` groups by (label text, username),
        # so every row carries the counts for one such pair; summing them
        # into the two counters yields the per-label and per-user totals.
        grouped = annotations.values('label__text', 'user__username').annotate(
            Count('label'), Count('user'))
        for row in grouped:
            label_count[row['label__text']] += row['label__count']
            user_count[row['user__username']] += row['user__count']

        return label_count, user_count
19+
20+
21+
class Seq2seqAnnotationManager(Manager):
    """Manager that adds per-project text and user annotation statistics.

    Unlike a label-based annotation, the grouping key here is the
    annotation's own ``text`` field rather than a related label's text.
    """

    def get_label_per_data(self, project):
        """Count a project's annotations per annotation text and per username.

        Args:
            project: Object exposing a ``documents`` related manager; only
                annotations whose document belongs to it are counted.

        Returns:
            A ``(label_count, user_count)`` tuple of ``Counter`` objects,
            keyed by annotation text and by username respectively.
        """
        label_count = Counter()
        user_count = Counter()

        # ``project.documents.all()`` is already a queryset and can be handed
        # to ``__in`` directly; cloning it again with ``.all()`` adds nothing.
        annotations = self.filter(document_id__in=project.documents.all())

        # ``values()`` before ``annotate()`` groups by (text, username), so
        # every row carries the counts for one such pair; summing them into
        # the two counters yields the per-text and per-user totals.
        grouped = annotations.values('text', 'user__username').annotate(
            Count('text'), Count('user'))
        for row in grouped:
            label_count[row['text']] += row['text__count']
            user_count[row['user__username']] += row['user__count']

        return label_count, user_count

app/api/models.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
from django.core.exceptions import ValidationError
88
from polymorphic.models import PolymorphicModel
99

10+
from .managers import AnnotationManager, Seq2seqAnnotationManager
11+
1012
DOCUMENT_CLASSIFICATION = 'DocumentClassification'
1113
SEQUENCE_LABELING = 'SequenceLabeling'
1214
SEQ2SEQ = 'Seq2seq'
@@ -191,6 +193,8 @@ def __str__(self):
191193

192194

193195
class Annotation(models.Model):
196+
objects = AnnotationManager()
197+
194198
prob = models.FloatField(default=0.0)
195199
manual = models.BooleanField(default=False)
196200
user = models.ForeignKey(User, on_delete=models.CASCADE)
@@ -224,6 +228,9 @@ class Meta:
224228

225229

226230
class Seq2seqAnnotation(Annotation):
231+
# Override AnnotationManager for custom functionality
232+
objects = Seq2seqAnnotationManager()
233+
227234
document = models.ForeignKey(Document, related_name='seq2seq_annotations', on_delete=models.CASCADE)
228235
text = models.CharField(max_length=500)
229236

app/api/serializers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -160,5 +160,5 @@ class Seq2seqAnnotationSerializer(serializers.ModelSerializer):
160160

161161
class Meta:
162162
model = Seq2seqAnnotation
163-
fields = ('id', 'text', 'user', 'document')
163+
fields = ('id', 'text', 'user', 'document', 'prob')
164164
read_only_fields = ('user',)

app/api/utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -373,6 +373,7 @@ def render(self, data, accepted_media_type=None, renderer_context=None):
373373
ensure_ascii=self.ensure_ascii,
374374
allow_nan=not self.strict) + '\n'
375375

376+
376377
class JSONPainter(object):
377378

378379
def paint(self, documents):
@@ -406,6 +407,7 @@ def paint_labels(documents, labels):
406407
data.append(d)
407408
return data
408409

410+
409411
class CSVPainter(JSONPainter):
410412

411413
def paint(self, documents):

app/api/views.py

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
from collections import Counter
2-
31
from django.conf import settings
42
from django.shortcuts import get_object_or_404, redirect
53
from django_filters.rest_framework import DjangoFilterBackend
@@ -85,15 +83,8 @@ def progress(self, project):
8583
return {'total': total, 'remaining': remaining}
8684

8785
def label_per_data(self, project):
88-
label_count = Counter()
89-
user_count = Counter()
9086
annotation_class = project.get_annotation_class()
91-
docs = project.documents.all()
92-
annotations = annotation_class.objects.filter(document_id__in=docs.all())
93-
for d in annotations.values('label__text', 'user__username').annotate(Count('label'), Count('user')):
94-
label_count[d['label__text']] += d['label__count']
95-
user_count[d['user__username']] += d['user__count']
96-
return label_count, user_count
87+
return annotation_class.objects.get_label_per_data(project=project)
9788

9889

9990
class ApproveLabelsAPI(APIView):

0 commit comments

Comments
 (0)