Skip to content

Commit 033fda0

Browse files
committed
feat: add search_studios tool
1 parent 026a223 commit 033fda0

File tree

5 files changed

+285
-0
lines changed

5 files changed

+285
-0
lines changed

demo.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,31 @@ async def demo_search_datasets(client: Client) -> None:
124124
print()
125125

126126

127+
async def demo_search_studios(client: Client) -> None:
128+
"""Demo searching studios."""
129+
tool_name = "search_studios"
130+
print_step_title(tool_name, "🔍 Search studios (keyword='TTS', sort='VisitsCount', limit 3 results)")
131+
132+
result = await client.call_tool(tool_name, {"query": "TTS", "sort": "VisitsCount", "limit": 3})
133+
data = parse_tool_response(result)
134+
135+
if isinstance(data, list) and data:
136+
summaries = []
137+
for studio in data:
138+
name = studio.get("name", "N/A")
139+
chinese_name = studio.get("chinese_name", "N/A")
140+
status = studio.get("status", "N/A")
141+
stars = studio.get("stars", 0)
142+
visits = studio.get("visits", 0)
143+
144+
summaries.append(f"{name} ({chinese_name}) - Status={status}, Stars={stars}, Visits={visits}")
145+
146+
print(f" • Result: Found {len(data)} items - {' | '.join(summaries)}")
147+
else:
148+
print(" • Result: No studios found")
149+
print()
150+
151+
127152
async def demo_search_papers(client: Client) -> None:
128153
"""Demo searching papers."""
129154
tool_name = "search_papers"
@@ -242,6 +267,7 @@ async def main() -> None:
242267
await demo_environment_info(client)
243268
await demo_search_models(client)
244269
await demo_search_datasets(client)
270+
await demo_search_studios(client)
245271
await demo_search_papers(client)
246272
await demo_search_mcp_servers(client)
247273

src/modelscope_mcp_server/server.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from .tools.mcp import register_mcp_tools
1919
from .tools.model import register_model_tools
2020
from .tools.paper import register_paper_tools
21+
from .tools.studio import register_studio_tools
2122
from .utils.metadata import get_server_name_with_version
2223

2324
logger = logging.get_logger(__name__)
@@ -42,6 +43,7 @@ def create_mcp_server() -> FastMCP:
4243
register_context_tools(mcp)
4344
register_model_tools(mcp)
4445
register_dataset_tools(mcp)
46+
register_studio_tools(mcp)
4547
register_paper_tools(mcp)
4648
register_mcp_tools(mcp)
4749
register_aigc_tools(mcp)
Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
"""ModelScope MCP Server Studio tools.
2+
3+
Provides tools for studio-related operations in the ModelScope MCP Server,
4+
such as searching for studios and retrieving studio details.
5+
"""
6+
7+
from typing import Annotated, Literal
8+
9+
from fastmcp import FastMCP
10+
from fastmcp.utilities import logging
11+
from pydantic import Field
12+
13+
from ..client import default_client
14+
from ..settings import settings
15+
from ..types import Studio
16+
17+
logger = logging.get_logger(__name__)
18+
19+
20+
def register_studio_tools(mcp: FastMCP) -> None:
21+
"""Register all studio-related tools with the MCP server.
22+
23+
Args:
24+
mcp (FastMCP): The MCP server instance
25+
26+
"""
27+
28+
@mcp.tool(
29+
annotations={
30+
"title": "Search Studios (创空间 AI 应用)",
31+
}
32+
)
33+
async def search_studios(
34+
query: Annotated[
35+
str,
36+
Field(
37+
description="Keyword to search for related studios. "
38+
"Leave empty to get all studios based on other filters."
39+
),
40+
] = "",
41+
domains: Annotated[
42+
list[Literal["multi-modal", "cv", "nlp", "audio", "AutoML"]] | None,
43+
Field(description="Domain categories to filter by"),
44+
] = None,
45+
sort: Annotated[
46+
Literal["Default", "gmt_modified", "VisitsCount", "StarsCount"],
47+
Field(description="Sort order"),
48+
] = "Default",
49+
limit: Annotated[int, Field(description="Maximum number of studios to return", ge=1, le=30)] = 10,
50+
) -> list[Studio]:
51+
"""Search for studios on ModelScope."""
52+
url = f"{settings.main_domain}/api/v1/dolphin/studios"
53+
54+
# Build criterion for filters
55+
criterion = []
56+
57+
# Add create_type filter (always include all types)
58+
criterion.append(
59+
{
60+
"category": "create_type",
61+
"predicate": "contains",
62+
"values": ["interactive", "programmatic"],
63+
}
64+
)
65+
66+
# Add domains filter
67+
if domains:
68+
criterion.append(
69+
{
70+
"category": "domains",
71+
"predicate": "contains",
72+
"values": domains,
73+
}
74+
)
75+
76+
request_data = {
77+
"Name": query,
78+
"Criterion": criterion,
79+
"SortBy": sort,
80+
"PageNumber": 1,
81+
"PageSize": limit,
82+
}
83+
84+
response = default_client.put(url, json_data=request_data)
85+
86+
studios_data = response.get("Data", {}).get("Studios", [])
87+
88+
studios = []
89+
for studio_data in studios_data:
90+
path = studio_data.get("Path", "")
91+
name = studio_data.get("Name", "")
92+
modelscope_url = f"{settings.main_domain}/studios/{path}/{name}"
93+
94+
if not path or not name:
95+
logger.warning(f"Skipping studio with invalid path or name: {studio_data}")
96+
continue
97+
98+
studio = Studio(
99+
id=str(studio_data.get("Id", "")),
100+
path=path,
101+
name=name,
102+
chinese_name=studio_data.get("ChineseName", ""),
103+
description=studio_data.get("Description", ""),
104+
created_by=studio_data.get("CreatedBy", ""),
105+
license=studio_data.get("License", ""),
106+
modelscope_url=modelscope_url,
107+
independent_url=studio_data.get("IndependentUrl"),
108+
cover_image=studio_data.get("CoverImage"),
109+
type=studio_data.get("Type", ""),
110+
status=studio_data.get("Status", ""),
111+
domains=studio_data.get("Domain") or [],
112+
stars=studio_data.get("Stars", 0),
113+
visits=studio_data.get("Visits", 0),
114+
created_at=studio_data.get("CreatedTime", 0),
115+
updated_at=studio_data.get("LastUpdatedTime", 0),
116+
deployed_at=studio_data.get("DeployedTime", 0),
117+
)
118+
studios.append(studio)
119+
120+
return studios

src/modelscope_mcp_server/types.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,38 @@ class Dataset(BaseModel):
7979
updated_at: Annotated[int, Field(description="Last updated time (unix timestamp, seconds)")] = 0
8080

8181

82+
class Studio(BaseModel):
83+
"""Studio information."""
84+
85+
# Basic information
86+
id: Annotated[str, Field(description="Unique studio ID")]
87+
path: Annotated[str, Field(description="Studio path, for example 'ttwwwaa'")]
88+
name: Annotated[str, Field(description="Studio name, for example 'ChatTTS_Speaker'")]
89+
chinese_name: Annotated[str, Field(description="Chinese name")]
90+
description: Annotated[str, Field(description="Studio description")]
91+
created_by: Annotated[str, Field(description="User who created the studio")]
92+
license: Annotated[str, Field(description="Open source license")]
93+
94+
# Links
95+
modelscope_url: Annotated[str, Field(description="Detail page URL on ModelScope")]
96+
independent_url: Annotated[str | None, Field(description="Independent access URL")] = None
97+
cover_image: Annotated[str | None, Field(description="Cover image URL")] = None
98+
99+
# Classification
100+
type: Annotated[str, Field(description="Studio type, for example 'programmatic' or 'interactive'")]
101+
status: Annotated[str, Field(description="Current status, for example 'Running'")]
102+
domains: Annotated[list[str], Field(description="Domain categories")] = []
103+
104+
# Metrics
105+
stars: Annotated[int, Field(description="Number of stars")] = 0
106+
visits: Annotated[int, Field(description="Number of visits")] = 0
107+
108+
# Timestamps
109+
created_at: Annotated[int, Field(description="Created time (unix timestamp, seconds)")] = 0
110+
updated_at: Annotated[int, Field(description="Last updated time (unix timestamp, seconds)")] = 0
111+
deployed_at: Annotated[int, Field(description="Deployed time (unix timestamp, seconds)")] = 0
112+
113+
82114
class Paper(BaseModel):
83115
"""Paper information."""
84116

tests/tools/test_search_studios.py

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
"""测试 search_studios 工具的功能。"""
2+
3+
import pytest
4+
from fastmcp import Client
5+
6+
7+
# Helper functions
8+
async def search_studios_helper(client, params):
9+
"""Helper function to search studios and validate basic response structure."""
10+
result = await client.call_tool("search_studios", params)
11+
assert hasattr(result, "data"), "Result should have data attribute"
12+
studios = result.data
13+
assert isinstance(studios, list), "Studios should be a list"
14+
return studios
15+
16+
17+
def print_studio_info(studio, extra_fields=None):
18+
"""Print studio information with optional extra fields."""
19+
base_info = (
20+
f"id: {studio.get('id', '')} | "
21+
f"name: {studio.get('name', '')} | "
22+
f"chinese_name: {studio.get('chinese_name', '')} | "
23+
f"type: {studio.get('type', '')}"
24+
)
25+
26+
if extra_fields:
27+
for field in extra_fields:
28+
base_info += f" | {field}: {studio.get(field, 0)}"
29+
30+
print(base_info)
31+
32+
33+
def print_studios_list(studios, description, extra_fields=None):
34+
"""Print a list of studios with description and optional extra fields."""
35+
print(f"✅ Received {len(studios)} studios {description}:")
36+
for studio in studios:
37+
print_studio_info(studio, extra_fields)
38+
39+
40+
def validate_studio_fields(studio):
41+
"""Validate that studio has all required fields."""
42+
required_fields = [
43+
"id",
44+
"path",
45+
"name",
46+
"chinese_name",
47+
"description",
48+
"created_by",
49+
"license",
50+
"modelscope_url",
51+
"type",
52+
"status",
53+
"domains",
54+
"stars",
55+
"visits",
56+
"created_at",
57+
"updated_at",
58+
"deployed_at",
59+
]
60+
61+
for field in required_fields:
62+
assert field in studio, f"Studio should have {field}"
63+
64+
65+
@pytest.mark.integration
66+
async def test_search_studios(mcp_server):
67+
async with Client(mcp_server) as client:
68+
studios = await search_studios_helper(client, {"query": "ChatTTS", "limit": 5})
69+
70+
print_studios_list(studios, "", ["stars"])
71+
72+
assert len(studios) > 0, "Studios should not be empty"
73+
validate_studio_fields(studios[0])
74+
75+
76+
@pytest.mark.integration
77+
async def test_search_studios_with_domain_filter(mcp_server):
78+
async with Client(mcp_server) as client:
79+
studios = await search_studios_helper(client, {"query": "音频", "domains": ["audio"], "limit": 3})
80+
81+
print_studios_list(studios, "with audio domain filter", ["visits"])
82+
83+
# Verify that all returned studios have audio domain
84+
for studio in studios:
85+
assert "audio" in studio.get("domains", []), f"Studio {studio.get('id', '')} should have audio domain"
86+
87+
88+
@pytest.mark.integration
89+
async def test_search_studios_with_domain_filter_cv(mcp_server):
90+
async with Client(mcp_server) as client:
91+
studios = await search_studios_helper(client, {"domains": ["cv"], "limit": 3})
92+
93+
print_studios_list(studios, "with cv domain filter", ["predicts"])
94+
95+
# Verify that all returned studios have cv domain
96+
for studio in studios:
97+
assert "cv" in studio.get("domains", []), f"Studio {studio.get('id', '')} should have cv domain"
98+
99+
100+
@pytest.mark.integration
101+
async def test_search_studios_sort_by_stars(mcp_server):
102+
async with Client(mcp_server) as client:
103+
studios = await search_studios_helper(client, {"query": "生成", "sort": "StarsCount", "limit": 3})
104+
105+
print_studios_list(studios, "sorted by stars", ["stars"])

0 commit comments

Comments
 (0)