|
| 1 | +# inference/admin.py |
| 2 | +from django.contrib import admin |
| 3 | + |
| 4 | +from inference.models.inference import ExternalJob, InferenceJob, ModelVersion |
| 5 | +from inference.models.inference_choice_fields import ( |
| 6 | + ExternalJobStatus, |
| 7 | + InferenceJobStatus, |
| 8 | +) |
| 9 | + |
| 10 | + |
| 11 | +class ExternalJobInline(admin.TabularInline): |
| 12 | + model = ExternalJob |
| 13 | + extra = 0 |
| 14 | + readonly_fields = ["status", "external_job_id", "created_at", "updated_at", "completed_at", "error_message"] |
| 15 | + can_delete = False |
| 16 | + fields = ["external_job_id", "status", "created_at", "updated_at", "completed_at", "error_message"] |
| 17 | + show_change_link = True |
| 18 | + |
| 19 | + |
| 20 | +@admin.register(ModelVersion) |
| 21 | +class ModelVersionAdmin(admin.ModelAdmin): |
| 22 | + list_display = ["api_identifier", "get_classification_type_display", "is_active", "description"] |
| 23 | + list_filter = ["classification_type", "is_active"] |
| 24 | + search_fields = ["api_identifier", "description"] |
| 25 | + actions = ["set_as_active"] |
| 26 | + |
| 27 | + def set_as_active(self, request, queryset): |
| 28 | + for model_version in queryset: |
| 29 | + model_version.set_as_active() |
| 30 | + self.message_user(request, "Selected model versions set as active.") |
| 31 | + |
| 32 | + set_as_active.short_description = "Set selected model versions as active" |
| 33 | + |
| 34 | + |
| 35 | +@admin.register(InferenceJob) |
| 36 | +class InferenceJobAdmin(admin.ModelAdmin): |
| 37 | + list_display = ["id", "collection", "model_version", "status_display", "created_at", "updated_at", "completed_at"] |
| 38 | + list_filter = ["status", "model_version__classification_type"] |
| 39 | + search_fields = ["collection__name", "model_version__api_identifier", "error_message"] |
| 40 | + readonly_fields = ["created_at", "updated_at", "completed_at", "status_display"] |
| 41 | + raw_id_fields = ["collection"] |
| 42 | + fields = [ |
| 43 | + "collection", |
| 44 | + "model_version", |
| 45 | + "status", |
| 46 | + "status_display", |
| 47 | + "error_message", |
| 48 | + "created_at", |
| 49 | + "updated_at", |
| 50 | + "completed_at", |
| 51 | + ] |
| 52 | + inlines = [ExternalJobInline] |
| 53 | + actions = ["initiate_job", "refresh_status", "unload_model"] |
| 54 | + |
| 55 | + def status_display(self, obj): |
| 56 | + return obj.get_status_display() |
| 57 | + |
| 58 | + status_display.short_description = "Status" |
| 59 | + |
| 60 | + def initiate_job(self, request, queryset): |
| 61 | + for job in queryset.filter(status=InferenceJobStatus.QUEUED): |
| 62 | + job.initiate() |
| 63 | + self.message_user(request, "Selected jobs have been initiated.") |
| 64 | + |
| 65 | + initiate_job.short_description = "Initiate selected queued jobs" |
| 66 | + |
| 67 | + def refresh_status(self, request, queryset): |
| 68 | + for job in queryset.filter(status=InferenceJobStatus.PENDING): |
| 69 | + job.refresh_external_jobs_status_and_store_results() |
| 70 | + job.reevaluate_progress_and_update_status() |
| 71 | + self.message_user(request, "Status of selected pending jobs has been refreshed.") |
| 72 | + |
| 73 | + refresh_status.short_description = "Refresh status of selected pending jobs" |
| 74 | + |
| 75 | + def unload_model(self, request, queryset): |
| 76 | + for job in queryset: |
| 77 | + job.unload_model() |
| 78 | + self.message_user(request, "Models for selected jobs have been unloaded.") |
| 79 | + |
| 80 | + unload_model.short_description = "Unload models for selected jobs" |
| 81 | + |
| 82 | + |
| 83 | +@admin.register(ExternalJob) |
| 84 | +class ExternalJobAdmin(admin.ModelAdmin): |
| 85 | + list_display = [ |
| 86 | + "id", |
| 87 | + "inference_job", |
| 88 | + "external_job_id", |
| 89 | + "status_display", |
| 90 | + "created_at", |
| 91 | + "updated_at", |
| 92 | + "completed_at", |
| 93 | + ] |
| 94 | + list_filter = ["status", "inference_job__model_version__classification_type"] |
| 95 | + search_fields = ["external_job_id", "inference_job__collection__name", "error_message"] |
| 96 | + readonly_fields = ["created_at", "updated_at", "completed_at", "status_display"] |
| 97 | + fields = [ |
| 98 | + "inference_job", |
| 99 | + "external_job_id", |
| 100 | + "status", |
| 101 | + "status_display", |
| 102 | + "url_ids", |
| 103 | + "results", |
| 104 | + "error_message", |
| 105 | + "created_at", |
| 106 | + "updated_at", |
| 107 | + "completed_at", |
| 108 | + ] |
| 109 | + actions = ["refresh_status"] |
| 110 | + |
| 111 | + def status_display(self, obj): |
| 112 | + return obj.get_status_display() |
| 113 | + |
| 114 | + status_display.short_description = "Status" |
| 115 | + |
| 116 | + def refresh_status(self, request, queryset): |
| 117 | + ongoing_statuses = [ExternalJobStatus.QUEUED, ExternalJobStatus.PENDING] |
| 118 | + for job in queryset.filter(status__in=ongoing_statuses): |
| 119 | + job.refresh_status_and_store_results() |
| 120 | + job.inference_job.reevaluate_progress_and_update_status() |
| 121 | + self.message_user(request, "Status of selected jobs has been refreshed.") |
| 122 | + |
| 123 | + refresh_status.short_description = "Refresh status of selected jobs" |
0 commit comments