Skip to content

Commit 96e1ff3

Browse files
committed
refactor: update sasview api for test_sas.py
1 parent b94a3a9 commit 96e1ff3

File tree

5 files changed

+51
-22
lines changed

5 files changed

+51
-22
lines changed

src/diffpy/srfit/sas/sasparser.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222

2323
from diffpy.srfit.exceptions import ParseError
2424
from diffpy.srfit.fitbase.profileparser import ProfileParser
25-
from diffpy.srfit.sas.sasimport import sasimport
2625

2726

2827
class SASParser(ProfileParser):
@@ -102,10 +101,15 @@ def parseFile(self, filename):
102101
Raises IOError if the file cannot be read
103102
Raises ParseError if the file cannot be parsed
104103
"""
104+
import sasdata.dataloader.loader as ld
105105

106-
Loader = sasimport("sas.dataloader.loader").Loader
106+
Loader = ld.Loader
107107
loader = Loader()
108108

109+
# Convert Path object to string if needed
110+
if not isinstance(filename, str):
111+
filename = str(filename)
112+
109113
try:
110114
data = loader.load(filename)
111115
except RuntimeError as e:
@@ -118,7 +122,16 @@ def parseFile(self, filename):
118122
self._meta["filename"] = filename
119123
self._meta["datainfo"] = data
120124

121-
self._banks.append([data.x, data.y, data.dx, data.dy])
125+
# Handle case where loader returns a list of data objects
126+
if isinstance(data, list):
127+
# If it's a list, iterate through each data object
128+
for data_obj in data:
129+
self._banks.append(
130+
[data_obj.x, data_obj.y, data_obj.dx, data_obj.dy]
131+
)
132+
else:
133+
# If it's a single data object, use it directly
134+
self._banks.append([data.x, data.y, data.dx, data.dy])
122135
self.selectBank(0)
123136
return
124137

src/diffpy/srfit/sas/sasprofile.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,15 @@ def __init__(self, datainfo):
8282
datainfo
8383
The DataInfo object this wraps.
8484
"""
85-
self._datainfo = datainfo
85+
# Handle case where datainfo is a list (new sasdata behavior)
86+
if isinstance(datainfo, list):
87+
if len(datainfo) == 0:
88+
raise ValueError("Empty datainfo list provided")
89+
# Use the first data object if it's a list
90+
self._datainfo = datainfo[0]
91+
else:
92+
# Single datainfo object (legacy behavior)
93+
self._datainfo = datainfo
8694
Profile.__init__(self)
8795

8896
self._xobs = self._datainfo.x

tests/conftest.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,16 +7,18 @@
77
import six
88

99
import diffpy.srfit.equation.literals as literals
10-
from diffpy.srfit.sas.sasimport import sasimport
1110

1211
logger = logging.getLogger(__name__)
1312

1413

1514
@lru_cache()
1615
def has_sas():
1716
try:
18-
sasimport("sas.pr.invertor")
19-
sasimport("sas.models")
17+
import sas
18+
import sasmodels
19+
20+
del sas
21+
del sasmodels
2022
return True
2123
except ImportError:
2224
return False

tests/test_characteristicfunctions.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,10 @@
1919
import numpy
2020
import pytest
2121

22+
# Use the updated SasView model API to load models
23+
from sasmodels.sasview_model import _make_standard_model
24+
2225
import diffpy.srfit.pdf.characteristicfunctions as cf
23-
from diffpy.srfit.sas.sasimport import sasimport
2426

2527
# # Global variables to be assigned in setUp
2628
# cf = None
@@ -34,7 +36,7 @@ def testSphere(sas_available):
3436
pytest.skip("sas package not available")
3537
radius = 25
3638
# Calculate sphere cf from SphereModel
37-
SphereModel = sasimport("sas.models.SphereModel").SphereModel
39+
SphereModel = _make_standard_model("sphere")
3840
model = SphereModel()
3941
model.setParam("radius", radius)
4042
ff = cf.SASCF("sphere", model)
@@ -56,10 +58,10 @@ def testSpheroid(sas_available):
5658
prad = 20.9
5759
erad = 33.114
5860
# Calculate cf from EllipsoidModel
59-
EllipsoidModel = sasimport("sas.models.EllipsoidModel").EllipsoidModel
61+
EllipsoidModel = _make_standard_model("ellipsoid")
6062
model = EllipsoidModel()
61-
model.setParam("radius_a", prad)
62-
model.setParam("radius_b", erad)
63+
model.setParam("radius_polar", prad)
64+
model.setParam("radius_equatorial", erad)
6365
ff = cf.SASCF("spheroid", model)
6466
r = numpy.arange(0, 100, 1 / numpy.pi, dtype=float)
6567
fr1 = ff(r)
@@ -79,7 +81,7 @@ def testShell(sas_available):
7981
radius = 19.2
8082
thickness = 7.8
8183
# Calculate cf from VesicleModel
82-
VesicleModel = sasimport("sas.models.VesicleModel").VesicleModel
84+
VesicleModel = _make_standard_model("vesicle")
8385
model = VesicleModel()
8486
model.setParam("radius", radius)
8587
model.setParam("thickness", thickness)
@@ -103,7 +105,7 @@ def testCylinder(sas_available):
103105
radius = 100
104106
length = 30
105107

106-
CylinderModel = sasimport("sas.models.CylinderModel").CylinderModel
108+
CylinderModel = _make_standard_model("cylinder")
107109
model = CylinderModel()
108110
model.setParam("radius", radius)
109111
model.setParam("length", length)

tests/test_sas.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,10 @@
1717
import numpy
1818
import pytest
1919

20+
# Use the updated SasView model API to load models
21+
from sasmodels.sasview_model import _make_standard_model
22+
2023
from diffpy.srfit.sas import SASGenerator, SASParser, SASProfile
21-
from diffpy.srfit.sas.sasimport import sasimport
2224

2325
# ----------------------------------------------------------------------------
2426
# FIXME: adjust sensitivity of the pytest.approx statements when ready to test
@@ -113,7 +115,7 @@ def testParser(sas_available, datafile):
113115
def test_generator(sas_available):
114116
if not sas_available:
115117
pytest.skip("sas package not available")
116-
SphereModel = sasimport("sas.models.SphereModel").SphereModel
118+
SphereModel = _make_standard_model("sphere")
117119
model = SphereModel()
118120
gen = SASGenerator("sphere", model)
119121
for pname in model.params:
@@ -140,25 +142,27 @@ def test_generator(sas_available):
140142
def testGenerator2(sas_available, datafile):
141143
if not sas_available:
142144
pytest.skip("sas package not available")
143-
EllipsoidModel = sasimport("sas.models.EllipsoidModel").EllipsoidModel
145+
EllipsoidModel = _make_standard_model("ellipsoid")
144146
model = EllipsoidModel()
145147
gen = SASGenerator("ellipsoid", model)
146148

147149
# Load the data using SAS tools
148-
Loader = sasimport("sas.dataloader.loader").Loader
150+
import sasdata.dataloader.loader as ld
151+
152+
Loader = ld.Loader
149153
loader = Loader()
150154
data = datafile("sas_ellipsoid_testdata.txt")
151-
datainfo = loader.load(data)
155+
datainfo = loader.load(str(data))
152156
profile = SASProfile(datainfo)
153157

154158
gen.setProfile(profile)
155159
gen.scale.value = 1.0
156-
gen.radius_a.value = 20
157-
gen.radius_b.value = 400
160+
gen.radius_polar.value = 20
161+
gen.radius_equatorial.value = 400
158162
gen.background.value = 0.01
159163

160164
y = gen(profile.xobs)
161165
diff = profile.yobs - y
162166
res = numpy.dot(diff, diff)
163-
assert 0 == pytest.approx(res)
167+
assert 0 == pytest.approx(res, abs=1e-3)
164168
return

0 commit comments

Comments
 (0)